Zor-X-L/cuda-matmul-bench

This repository doesn't specify a license. Please pay attention to the specific project description and its upstream code dependencies when using it.
matmul_fp4.cpp (21.64 KB)
/******************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* NVIDIA CORPORATION and its licensors retain all intellectual property
* and proprietary rights in and to this software, related documentation
* and any modifications thereto. Any use, reproduction, disclosure or
* distribution of this software and related documentation without an express
* license agreement from NVIDIA CORPORATION is strictly prohibited.
******************************************************************************/
#include <cusparseLt.h> // cusparseLt header
#include <cuda_runtime_api.h> // cudaMalloc, cudaMemcpy, etc.
#include <cuda_fp16.h> // __half
#include <cuda_bf16.h> // __nv_bfloat16
#include <cuda_fp8.h> // __nv_fp8_e4m3
#include <cuda_fp4.h> // __nv_fp4_e2m1, __nv_fp4x2_e2m1
#include <omp.h>
#include <cstdio> // std::printf
#include <cstdlib> // std::atoi, std::exit
#include <vector>
#include <random>
#define FP16 1000 // SM8.0
#define INT8 1001 // SM8.0
#define FP8 1002 // SM9.0
#define FP4 1003 // SM10.0
/*
* Choose your data type for matrices A and B
*/
#ifndef AB_TYPE
//#define AB_TYPE FP16
//#define AB_TYPE INT8
//#define AB_TYPE FP8
#define AB_TYPE FP4
#endif
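// The default can also be overridden on the compiler command line instead of
// editing this file. A hypothetical build line (adjust include/library paths,
// the -arch value, and the binary name for your setup):
//   nvcc -DAB_TYPE=1002 -arch=sm_90 -Xcompiler -fopenmp matmul_fp4.cpp -lcusparseLt -o matmul_bench
// Note that as written only the FP4 path compiles end-to-end: ABSCALE_t, the
// packed fp4x2 buffer sizes, and the data-generation code below all assume
// AB_TYPE == FP4.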
#if AB_TYPE == FP16
using AB_t = __half;
using C_t = __half; // can be __half, __nv_bfloat16, float
using COMPUTE_t = float;
#elif AB_TYPE == INT8
using AB_t = int8_t;
using C_t = __half; // can be __half, __nv_bfloat16, int8_t, int
using COMPUTE_t = int;
#elif AB_TYPE == FP8
using AB_t = __nv_fp8_e4m3;
using C_t = __half; // can be __half, __nv_bfloat16, float
using COMPUTE_t = float;
#elif AB_TYPE == FP4
using AB_t = __nv_fp4x2_e2m1;
using ABSCALE_t = __nv_fp8_e4m3;
using C_t = __half; // can be __half, __nv_bfloat16, float
using COMPUTE_t = float;
#endif
template <typename value_t>
struct cuda_type { };
template <>
struct cuda_type <__half> {
static constexpr cudaDataType value = CUDA_R_16F;
};
template <>
struct cuda_type <__nv_bfloat16> {
static constexpr cudaDataType value = CUDA_R_16BF;
};
template <>
struct cuda_type <__nv_fp8_e4m3> {
static constexpr cudaDataType value = CUDA_R_8F_E4M3;
};
template <>
struct cuda_type <__nv_fp4x2_e2m1> {
static constexpr cudaDataType value = CUDA_R_4F_E2M1;
};
template <>
struct cuda_type <int8_t> {
static constexpr cudaDataType value = CUDA_R_8I;
};
template <>
struct cuda_type <int> {
static constexpr cudaDataType value = CUDA_R_32I;
};
template <typename value_t>
struct cusparse_compute_type { };
template <>
struct cusparse_compute_type<float> {
static constexpr cusparseComputeType value = CUSPARSE_COMPUTE_32F;
};
template <>
struct cusparse_compute_type<int> {
static constexpr cusparseComputeType value = CUSPARSE_COMPUTE_32I;
};
#define CHECK_CUDA(func) \
{ \
cudaError_t status = (func); \
if (status != cudaSuccess) { \
printf("CUDA API failed at line %d with error: %s (%d)\n", \
__LINE__, cudaGetErrorString(status), status); \
std::exit(EXIT_FAILURE); \
} \
}
#define CHECK_CUSPARSE(func) \
{ \
cusparseStatus_t status = (func); \
if (status != CUSPARSE_STATUS_SUCCESS) { \
printf("CUSPARSE API failed at line %d with error: %s (%d)\n", \
__LINE__, cusparseGetErrorString(status), status); \
std::exit(EXIT_FAILURE); \
} \
}
constexpr int EXIT_UNSUPPORTED = 2;
void run(const int dim_index, const size_t dim, const int repeat);
size_t pack_fp4_to_fp4x2(__nv_fp4_e2m1* in, size_t in_elements, __nv_fp4x2_e2m1* out, size_t out_elements);
size_t unpack_fp4x2_to_fp4(__nv_fp4x2_e2m1* in, size_t in_elements, __nv_fp4_e2m1* out, size_t out_elements);
void test_fp4_fp4x2_conversion();
int main(int argc, char* argv[]) {
int major_cc, minor_cc;
CHECK_CUDA(cudaDeviceGetAttribute(&major_cc, cudaDevAttrComputeCapabilityMajor, 0));
CHECK_CUDA(cudaDeviceGetAttribute(&minor_cc, cudaDevAttrComputeCapabilityMinor, 0));
if (!(major_cc == 8 && minor_cc == 0) &&
!(major_cc == 8 && minor_cc == 6) &&
!(major_cc == 8 && minor_cc == 7) &&
!(major_cc == 8 && minor_cc == 9) &&
!(major_cc == 9 && minor_cc == 0) &&
!(major_cc == 10 && minor_cc == 0) &&
!(major_cc == 12 && minor_cc == 0)) {
std::printf("\ncusparseLt is supported only on GPU devices with"
" compute capability == 8.0, 8.6, 8.7, 8.9, 9.0 10.0 12.0 current: %d.%d\n\n",
major_cc, minor_cc);
return EXIT_UNSUPPORTED;
}
const size_t dims[] = {
128, 256, 384, 512, 768, 1024, 1408, 2048, 2944, 4096, 5760, 8192, 11648, 16384, 23168, 32768, 46336 };
std::printf("Usage: %s [dim_index_start] [dim_index_end] [repeat]\n", argv[0]);
std::printf(" default values: dim_index_start=0, dim_index_end=16, repeat=1\n");
int dim_index_start = 0;
int dim_index_end = 16;
int repeat = 1;
if (argc - 1 >= 1) {
dim_index_start = std::atoi(argv[1]);
}
if (argc - 1 >= 2) {
dim_index_end = std::atoi(argv[2]);
}
if (argc - 1 >= 3) {
repeat = std::atoi(argv[3]);
}
const int max_index = static_cast<int>(sizeof(dims) / sizeof(dims[0])) - 1;
if (dim_index_start < 0 || dim_index_start > dim_index_end || dim_index_end > max_index) {
std::printf("dim_index range must satisfy 0 <= start <= end <= %d\n", max_index);
return EXIT_FAILURE;
}
for (int i = dim_index_start; i <= dim_index_end; ++i) {
run(i, dims[i], repeat);
}
return EXIT_SUCCESS;
}
void run(const int dim_index, const size_t dim, const int repeat) {
// Host problem definition: square m = n = k matrices in row-major order,
// allocated dynamically on the host
const size_t m = dim;
const size_t n = m;
const size_t k = m;
auto order = CUSPARSE_ORDER_ROW;
auto opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto opB = CUSPARSE_OPERATION_TRANSPOSE;
auto type_AB = cuda_type<AB_t>::value;
auto type_C = cuda_type<C_t>::value;
auto compute_type = cusparse_compute_type<COMPUTE_t>::value;
bool matmul_search = true;
bool is_rowmajor = (order == CUSPARSE_ORDER_ROW);
bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE);
bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE);
auto num_A_rows = (isA_transposed) ? k : m;
auto num_A_cols = (isA_transposed) ? m : k;
auto num_B_rows = (isB_transposed) ? n : k;
auto num_B_cols = (isB_transposed) ? k : n;
auto num_C_rows = m;
auto num_C_cols = n;
unsigned alignment = 16;
auto lda = (is_rowmajor) ? num_A_cols : num_A_rows;
auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows;
auto ldc = (is_rowmajor) ? num_C_cols : num_C_rows;
auto A_height = (is_rowmajor) ? num_A_rows : num_A_cols;
auto B_height = (is_rowmajor) ? num_B_rows : num_B_cols;
auto C_height = (is_rowmajor) ? num_C_rows : num_C_cols;
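// __nv_fp4x2_e2m1 packs two 4-bit values into one byte (sizeof == 1), so the
// byte sizes of A and B below are the element counts divided by 2. The scale
// buffers hold one fp8 (UE4M3) factor per block of 32 elements, matching the
// VEC32_UE4M3 scale mode set on the matmul descriptor further down.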
auto A_size = A_height * lda / 2 * sizeof(AB_t);
auto B_size = B_height * ldb / 2 * sizeof(AB_t);
auto C_size = C_height * ldc * sizeof(C_t);
auto AB_scale_block = 32;
auto A_scale_size = A_height * lda / AB_scale_block * sizeof(ABSCALE_t);
auto B_scale_size = B_height * ldb / AB_scale_block * sizeof(ABSCALE_t);
auto hA = new AB_t[A_size / sizeof(AB_t)];
auto hB = new AB_t[B_size / sizeof(AB_t)];
auto hC = new C_t[C_size / sizeof(C_t)];
auto hA_scale = new ABSCALE_t[A_scale_size / sizeof(ABSCALE_t)];
auto hB_scale = new ABSCALE_t[B_scale_size / sizeof(ABSCALE_t)];
{
// unit scale factors: the scaled matmul then matches an unscaled host reference
for (size_t i = 0; i < A_scale_size / sizeof(ABSCALE_t); ++i) {
hA_scale[i] = static_cast<ABSCALE_t>(1);
}
for (size_t i = 0; i < B_scale_size / sizeof(ABSCALE_t); ++i) {
hB_scale[i] = static_cast<ABSCALE_t>(1);
}
int num_threads = omp_get_max_threads();
std::vector<std::mt19937_64> randgens(num_threads);
std::random_device rd;
for (auto& randgen : randgens) {
randgen.seed(rd());
}
#pragma omp parallel
{
int tid = omp_get_thread_num();
auto& randgen = randgens[tid];
std::uniform_int_distribution<int> dist(-2, 2);
#pragma omp for
for (int i = 0; i < m * k / 2; i++) {
// Alternate one random fp4x2 pair with one zero pair: every 4 consecutive
// elements of A look like [x x 0 0], which already satisfies the 2:4
// structured-sparsity constraint, so cusparseLtSpMMAPrune() is unnecessary.
float2 f;
if (i % 2 == 0) {
f.x = dist(randgen);
f.y = dist(randgen);
} else {
f.x = 0;
f.y = 0;
}
hA[i] = __nv_fp4x2_e2m1(f);
}
#pragma omp for
for (int i = 0; i < k * n / 2; i++) {
float2 f;
f.x = dist(randgen);
f.y = dist(randgen);
hB[i] = __nv_fp4x2_e2m1(f);
}
#pragma omp for
for (int i = 0; i < m * n; i++) {
hC[i] = static_cast<C_t>(dist(randgen));
}
}
}
float alpha = 1.0f;
float beta = 1.0f;
//--------------------------------------------------------------------------
// Device memory management
AB_t* dA, * dB, * dA_compressed;
C_t* dC, * dD;
ABSCALE_t* dA_scale, * dB_scale;
int* d_valid;
CHECK_CUDA(cudaMalloc((void**)&dA, A_size));
CHECK_CUDA(cudaMalloc((void**)&dB, B_size));
CHECK_CUDA(cudaMalloc((void**)&dC, C_size));
CHECK_CUDA(cudaMalloc((void**)&dA_scale, A_scale_size));
CHECK_CUDA(cudaMalloc((void**)&dB_scale, B_scale_size));
CHECK_CUDA(cudaMalloc((void**)&d_valid, sizeof(int)));
dD = dC;
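// D aliases C, so each cusparseLtMatmul() call computes C = alpha*op(A)*op(B) + beta*C
// in place; with beta == 1 repeated calls accumulate into C, which is why dC
// is re-uploaded after the kernel search below.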
CHECK_CUDA(cudaMemcpy(dA, hA, A_size, cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(dB, hB, B_size, cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(dC, hC, C_size, cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(dA_scale, hA_scale, A_scale_size, cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(dB_scale, hB_scale, B_scale_size, cudaMemcpyHostToDevice));
//--------------------------------------------------------------------------
cusparseLtHandle_t handle;
cusparseLtMatDescriptor_t matA, matB, matC;
cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
cusparseLtMatmulMatrixScale_t AB_scale_mode = CUSPARSELT_MATMUL_MATRIX_SCALE_VEC32_UE4M3;
cudaStream_t stream = 0;
CHECK_CUSPARSE(cusparseLtInit(&handle));
// matrix descriptor initialization
CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
&handle, &matA, num_A_rows,
num_A_cols, lda, alignment,
type_AB, order,
CUSPARSELT_SPARSITY_50_PERCENT));
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(
&handle, &matB, num_B_rows,
num_B_cols, ldb, alignment,
type_AB, order));
CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(
&handle, &matC, num_C_rows,
num_C_cols, ldc, alignment,
type_C, order));
// matmul, algorithm selection, and plan initialization
CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
&handle, &matmul, opA, opB,
&matA, &matB, &matC, &matC,
compute_type));
CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit(
&handle, &alg_sel, &matmul,
CUSPARSELT_MATMUL_ALG_DEFAULT));
CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel));
CHECK_CUSPARSE(cusparseLtMatmulDescSetAttribute(
&handle,
&matmul,
CUSPARSELT_MATMUL_SPARSE_MAT_POINTER,
&dA,
sizeof(dA)));
CHECK_CUSPARSE(cusparseLtMatmulDescSetAttribute(
&handle,
&matmul,
CUSPARSELT_MATMUL_A_SCALE_MODE,
&AB_scale_mode,
sizeof(AB_scale_mode)));
CHECK_CUSPARSE(cusparseLtMatmulDescSetAttribute(
&handle,
&matmul,
CUSPARSELT_MATMUL_A_SCALE_POINTER,
&dA_scale,
sizeof(dA_scale)));
CHECK_CUSPARSE(cusparseLtMatmulDescSetAttribute(
&handle,
&matmul,
CUSPARSELT_MATMUL_B_SCALE_MODE,
&AB_scale_mode,
sizeof(AB_scale_mode)));
CHECK_CUSPARSE(cusparseLtMatmulDescSetAttribute(
&handle,
&matmul,
CUSPARSELT_MATMUL_B_SCALE_POINTER,
&dB_scale,
sizeof(dB_scale)));
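// All scale factors were initialized to 1 above, so the VEC32 scaling path is
// exercised without changing the numerical result; the host reference check
// below can therefore ignore the scales.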
//--------------------------------------------------------------------------
// A was generated with a valid 2:4 sparsity pattern, so the explicit prune
// step is skipped; only verify that the pattern is acceptable
//CHECK_CUSPARSE(cusparseLtSpMMAPrune(&handle, &matmul, dA, dA, CUSPARSELT_PRUNE_SPMMA_TILE, stream));
CHECK_CUSPARSE(cusparseLtSpMMAPruneCheck(&handle, &matmul, dA, d_valid, stream));
int is_valid;
CHECK_CUDA(cudaMemcpyAsync(&is_valid, d_valid, sizeof(int), cudaMemcpyDeviceToHost, stream));
CHECK_CUDA(cudaStreamSynchronize(stream));
if (is_valid != 0) {
std::printf("!!!! The matrix has been pruned in a wrong way. "
"cusparseLtMatmul will not provide correct results\n");
std::exit(EXIT_FAILURE);
}
//--------------------------------------------------------------------------
// Compress the A matrix
size_t compressed_size, compressed_buffer_size;
void* dA_compressedBuffer;
CHECK_CUSPARSE(cusparseLtSpMMACompressedSize(&handle, &plan, &compressed_size, &compressed_buffer_size));
std::printf("A_size=%zu, compressed_size=%zu, compressed_buffer_size=%zu\n", A_size, compressed_size, compressed_buffer_size);
CHECK_CUDA(cudaMalloc((void**)&dA_compressed, compressed_size));
CHECK_CUDA(cudaMalloc((void**)&dA_compressedBuffer, compressed_buffer_size));
CHECK_CUSPARSE(cusparseLtSpMMACompress(&handle, &plan, dA, dA_compressed, dA_compressedBuffer, stream));
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
size_t workspace_size;
CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size));
void* d_workspace;
CHECK_CUDA(cudaMalloc((void**)&d_workspace, workspace_size));
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Search for the best kernel (the winning algorithm is stored in the plan)
if (matmul_search) {
CHECK_CUSPARSE(cusparseLtMatmulSearch(&handle, &plan, &alpha,
dA_compressed, dB, &beta,
dC, dD, d_workspace,
&stream, 1));
// dC accumulates so reset dC for correctness check
CHECK_CUDA(cudaMemcpy(dC, hC, C_size, cudaMemcpyHostToDevice));
}
// Perform the matrix multiplication
cudaEvent_t start, stop;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
for (int i = 0; i < repeat; ++i) {
CHECK_CUSPARSE(cusparseLtMatmul(&handle, &plan, &alpha, dA_compressed, dB, &beta, dC, dD, d_workspace, &stream, 1));
}
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
float milliseconds = 0;
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
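// A dense m*n*k matmul performs 2*m*n*k flops (one multiply + one add per
// term); dividing by (milliseconds * 1e9) folds the ms -> s conversion into
// the 1e12 of "tera", giving effective dense TFLOP/s.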
std::printf("dim_index=%d, m=%zu, n=%zu, k=%zu, repeat=%d, milliseconds=%f, tflops=%f\n",
dim_index, m, n, k, repeat, milliseconds, static_cast<float>(m) * n * k * 2 * repeat / milliseconds / 1e9);
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// destroy plan and handle
CHECK_CUSPARSE(cusparseLtMatDescriptorDestroy(&matA));
CHECK_CUSPARSE(cusparseLtMatDescriptorDestroy(&matB));
CHECK_CUSPARSE(cusparseLtMatDescriptorDestroy(&matC));
CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan));
CHECK_CUSPARSE(cusparseLtDestroy(&handle));
//--------------------------------------------------------------------------
// device result check: only for small problems, where the naive O(m*n*k)
// host reference is affordable and exactly one matmul was accumulated into C.
// A already satisfies the 2:4 pattern, so the host copies of A and B match
// the device data.
if (repeat == 1 && static_cast<float>(m) * n * k <= 512.0 * 512.0 * 512.0 + 0.5)
{
//CHECK_CUDA(cudaMemcpy(hA, dA, A_size, cudaMemcpyDeviceToHost));
__nv_fp4_e2m1* hA_unpacked = new __nv_fp4_e2m1[A_size * 2];
__nv_fp4_e2m1* hB_unpacked = new __nv_fp4_e2m1[B_size * 2];
unpack_fp4x2_to_fp4(hA, A_size, hA_unpacked, A_size * 2);
unpack_fp4x2_to_fp4(hB, B_size, hB_unpacked, B_size * 2);
/*
std::printf("%f %f %f %f %f %f %f %f\n",
static_cast<float>(hA_unpacked[0]),
static_cast<float>(hA_unpacked[1]),
static_cast<float>(hA_unpacked[2]),
static_cast<float>(hA_unpacked[3]),
static_cast<float>(hA_unpacked[4]),
static_cast<float>(hA_unpacked[5]),
static_cast<float>(hA_unpacked[6]),
static_cast<float>(hA_unpacked[7]));
std::printf("%f %f %f %f %f %f %f %f\n",
static_cast<float>(hB_unpacked[0]),
static_cast<float>(hB_unpacked[1]),
static_cast<float>(hB_unpacked[2]),
static_cast<float>(hB_unpacked[3]),
static_cast<float>(hB_unpacked[4]),
static_cast<float>(hB_unpacked[5]),
static_cast<float>(hB_unpacked[6]),
static_cast<float>(hB_unpacked[7]));
*/
bool A_std_layout = (is_rowmajor != isA_transposed);
bool B_std_layout = (is_rowmajor != isB_transposed);
// host computation
C_t* hC_result = new C_t[C_height * ldc];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
COMPUTE_t sum = static_cast<COMPUTE_t>(0);
for (int k1 = 0; k1 < k; k1++) {
auto posA = (A_std_layout) ? i * lda + k1 : i + k1 * lda;
auto posB = (B_std_layout) ? k1 * ldb + j : k1 + j * ldb;
sum += static_cast<COMPUTE_t>(hA_unpacked[posA]) * // [i][k]
static_cast<COMPUTE_t>(hB_unpacked[posB]); // [k][j]
}
auto posC = (is_rowmajor) ? i * ldc + j : i + j * ldc;
hC_result[posC] = static_cast<C_t>(alpha * sum + beta * static_cast<float>(hC[posC])); // [i][j]
}
}
// reuse hC for device results
CHECK_CUDA(cudaMemcpy(hC, dD, C_size, cudaMemcpyDeviceToHost));
// host-device comparison
int correct = 1;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
auto pos = (is_rowmajor) ? i * ldc + j : i + j * ldc;
auto device_value = hC[pos];
auto host_value = hC_result[pos];
if (device_value != host_value) {
// exact comparison is safe here: all inputs are small integers, so every
// product and partial sum is exactly representable in float, and host and
// device round to C_t identically
std::printf("(%d, %d):\t%3.0f vs. %3.0f\n",
i, j, static_cast<float>(host_value), static_cast<float>(device_value));
correct = 0;
break;
}
}
}
if (correct) {
std::printf("matmul_example test PASSED\n");
}
else {
std::printf("matmul_example test FAILED: wrong result\n");
}
delete[] hC_result;
delete[] hA_unpacked;
delete[] hB_unpacked;
}
//--------------------------------------------------------------------------
// host memory deallocation
delete[] hA;
delete[] hB;
delete[] hC;
delete[] hA_scale;
delete[] hB_scale;
//--------------------------------------------------------------------------
// device memory deallocation
CHECK_CUDA(cudaFree(dA_compressed));
CHECK_CUDA(cudaFree(dA));
CHECK_CUDA(cudaFree(dB));
CHECK_CUDA(cudaFree(dC));
CHECK_CUDA(cudaFree(dA_scale));
CHECK_CUDA(cudaFree(dB_scale));
CHECK_CUDA(cudaFree(d_valid));
CHECK_CUDA(cudaFree(d_workspace));
CHECK_CUDA(cudaFree(dA_compressedBuffer));
}
size_t pack_fp4_to_fp4x2(__nv_fp4_e2m1* in, size_t in_elements, __nv_fp4x2_e2m1* out, size_t out_elements) {
size_t i;
size_t j;
for (i = 0, j = 0; i < in_elements && j < out_elements; i += 2, j += 1) {
float2 f;
f.x = static_cast<float>(in[i]);
// guard the high half so an odd-length input does not read past the end
f.y = (i + 1 < in_elements) ? static_cast<float>(in[i + 1]) : 0.0f;
out[j] = __nv_fp4x2_e2m1(f);
}
return j;
}
size_t unpack_fp4x2_to_fp4(__nv_fp4x2_e2m1* in, size_t in_elements, __nv_fp4_e2m1* out, size_t out_elements) {
size_t i;
size_t j;
for (i = 0, j = 0; i < in_elements && j < out_elements; i += 1, j += 2) {
float2 f = static_cast<float2>(in[i]);
out[j] = static_cast<__nv_fp4_e2m1>(f.x);
// guard the high half so an odd-sized output buffer is not overrun
if (j + 1 < out_elements) {
out[j + 1] = static_cast<__nv_fp4_e2m1>(f.y);
}
}
return j;
}
void test_fp4_fp4x2_conversion() {
__nv_fp4_e2m1 a[6] = { static_cast<__nv_fp4_e2m1>(1.0), static_cast<__nv_fp4_e2m1>(2.0), static_cast<__nv_fp4_e2m1>(3.0), static_cast<__nv_fp4_e2m1>(4.0), static_cast<__nv_fp4_e2m1>(3.0), static_cast<__nv_fp4_e2m1>(2.0) };
__nv_fp4x2_e2m1 b[3];
__nv_fp4_e2m1 c[6];
pack_fp4_to_fp4x2(a, 6, b, 3);
unpack_fp4x2_to_fp4(b, 3, c, 6);
std::printf("%f %f %f %f %f %f\n", static_cast<float>(a[0]), static_cast<float>(a[1]), static_cast<float>(a[2]), static_cast<float>(a[3]), static_cast<float>(a[4]), static_cast<float>(a[5]));
std::printf("%hhx %hhx %hhx\n", reinterpret_cast<unsigned char*>(b)[0], reinterpret_cast<unsigned char*>(b)[1], reinterpret_cast<unsigned char*>(b)[2]);
std::printf("%f %f %f %f %f %f\n", static_cast<float>(c[0]), static_cast<float>(c[1]), static_cast<float>(c[2]), static_cast<float>(c[3]), static_cast<float>(c[4]), static_cast<float>(c[5]));
}
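// Note: test_fp4_fp4x2_conversion() is a standalone sanity check for the
// pack/unpack helpers and is never called from main(); call it manually (e.g.
// at the top of main) to inspect the raw fp4x2 byte packing.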