# flash-kmeans

**Repository Path**: gitstr/flash-kmeans

## Basic Information

- **Project Name**: flash-kmeans
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Apache-2.0
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2026-06-20
- **Last Updated**: 2026-06-20

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# Flash-KMeans
IO-aware batched K-Means clustering implemented with Triton GPU kernels. This repository provides the official K-Means implementation of [Sparse VideoGen2](https://arxiv.org/pdf/2505.18875).

## Installation

Install flash-kmeans with `pip`:

```bash
pip install flash-kmeans
```

From source:

```bash
git clone https://github.com/svg-project/flash-kmeans.git
cd flash-kmeans
pip install -e .
```

## Usage

```python
import torch
from flash_kmeans import batch_kmeans_Euclid

x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)
```

We also provide an API similar to `faiss`/`sklearn`; see the [API docs](https://github.com/svg-project/flash-kmeans/blob/main/flash_kmeans/interface.py) for details.

## Benchmark

We compare the performance of our Triton implementation with the following baselines:

- [fast_pytorch_kmeans](https://github.com/DeMoriarty/fast_pytorch_kmeans): a PyTorch implementation of K-Means clustering.
- [fastkmeans(triton) / fastkmeans(torch)](https://github.com/AnswerDotAI/fastkmeans): another Triton implementation of K-Means clustering, and its PyTorch fallback.
- flash-kmeans(triton) / flash-kmeans(torch): our Triton implementation and its PyTorch fallback.
- batched torch kmeans: a naive batched implementation that does not guard against out-of-memory (OOM) errors.

All runs use an NVIDIA H200 GPU with FP16 precision and 128-dimensional data, varying the number of clusters (k), the number of data points (n), and the batch size (b). Our Triton implementation is significantly faster than these baselines.

Note: in Figure 1, fastkmeans(triton) fails with an error for k=100 and k=1000.

### Large tensor Benchmark

For inputs that cannot fit in GPU memory, we compare against fastkmeans(triton) using FP32 precision and 128-dimensional data, with the number of data points scaling from 256K to 268M (N = 2^18, 2^20, 2^22, 2^24, 2^26, 2^28) and cluster counts following K = √N (512, 1024, 2048, 4096, 8192, 16384).
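For inputs that cannot fit in GPU memory, the data has to stay on the host and pass through the GPU in pieces. The following is a minimal plain-PyTorch sketch of that pattern, assuming one Lloyd iteration (nearest-centroid assignment followed by a mean update) is computed chunk by chunk. `streamed_lloyd_step` and its `chunk_size` parameter are hypothetical names; the sketch does not reproduce flash-kmeans' Triton kernels or its actual transfer pipeline.

```python
import torch


def streamed_lloyd_step(x_cpu: torch.Tensor, centroids: torch.Tensor, chunk_size: int = 1 << 16):
    """One Lloyd (K-Means) iteration over a host-resident (N, D) tensor.

    Hypothetical helper for illustration only; not part of the flash-kmeans API.
    Returns (labels on the CPU, updated centroids on centroids.device).
    """
    device = centroids.device
    k, d = centroids.shape
    c = centroids.float()
    c_sq = (c * c).sum(dim=1)  # ||c_k||^2 for every centroid, shape (K,)
    sums = torch.zeros(k, d, device=device)  # per-cluster coordinate sums
    counts = torch.zeros(k, device=device)  # per-cluster point counts
    labels = torch.empty(x_cpu.shape[0], dtype=torch.long)

    for start in range(0, x_cpu.shape[0], chunk_size):
        # Copy one chunk to the compute device. The copy is asynchronous only
        # when x_cpu is in pinned (page-locked) memory and the device is CUDA.
        xb = x_cpu[start:start + chunk_size].to(device, non_blocking=True).float()
        # argmin_k ||x - c_k||^2 == argmin_k (||c_k||^2 - 2 x.c_k), because ||x||^2
        # is the same for every k; only a (chunk, K) score tile is materialised.
        scores = c_sq.unsqueeze(0) - 2.0 * (xb @ c.T)
        lab = scores.argmin(dim=1)
        labels[start:start + xb.shape[0]] = lab.cpu()
        # Accumulate statistics for the centroid update.
        sums.index_add_(0, lab, xb)
        counts.index_add_(0, lab, torch.ones_like(lab, dtype=sums.dtype))

    # Mean of assigned points; empty clusters keep their previous centroid.
    means = sums / counts.clamp(min=1).unsqueeze(1)
    new_c = torch.where(counts.unsqueeze(1) > 0, means, c)
    return labels, new_c.to(centroids.dtype)


# Usage on random data, with K = sqrt(N) as in this benchmark (N = 2^16, K = 2^8).
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2**16, 128)
if torch.cuda.is_available():
    x = x.pin_memory()  # page-locked host memory enables async host-to-device copies
cent = x[torch.randperm(x.shape[0])[:2**8]].to(device)  # random points as initial centroids
for _ in range(3):
    labels, cent = streamed_lloyd_step(x, cent, chunk_size=16_384)
```

Dropping the ‖x‖² term keeps only a `(chunk, K)` score tile on the device per step, so GPU memory scales with `chunk_size` rather than N. A production pipeline would also overlap the copy of the next chunk with compute on the current one (for example, with CUDA streams); this sketch synchronises on every chunk when it moves labels back to the CPU.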
The input tensor is generated randomly in pinned CPU memory; both flash-kmeans and fastkmeans transfer the data from CPU to GPU in chunks and compute on each chunk.

### Large-D and dtype support

The Triton assign kernel ships in two flavours and dispatches between them automatically based on input shape and dtype:

- **Small-D path** (existing kernel, `D ≤ 512`): one program loads `x_tile (BN, D)` once and streams over centroids in `BLOCK_K` chunks. Per-arch heuristics (H200, H100, A100, GB10) pick `(BLOCK_N, BLOCK_K, num_warps, num_stages)` from a hand-tuned table derived from `flash-kmeans-tune` grid sweeps.
- **Split-D path** (new; used when `D > 512`, or whenever the small-D kernel cannot fit in shared memory even at the minimum tile size): an outer loop over K and an inner loop over D tiled by `BLOCK_D`. The cross-term accumulator `(BN, BK)` is held in registers across the D loop, so the K-streaming property (no `(B, N, K)` distance matrix is materialised) is preserved.

**Unknown GPUs**: when the GPU does not match any of the tuned families (H200/H100/A100/GB10), the wrapper unconditionally dispatches to the split-D kernel with a conservative fallback config (`BLOCK_N=32, BLOCK_K=32, BLOCK_D=32, num_warps=4, num_stages=1`). This avoids any reliance on per-arch tuning data we don't have, and the split-D path's `_fit_config_to_smem_split_d` post-process guarantees that the launch fits in shared memory regardless of the actual budget. Performance will be suboptimal on unfamiliar architectures; re-tune via [flash-kmeans-tune](https://github.com/svg-project/flash-kmeans-tune) to populate the corresponding `_heuristic_euclid_config_