Hi, thanks for open-sourcing MSA.
In Section 4.1 of the paper, the Top-k kernel is described as a heap-based implementation:
> Each of the warp’s 32 lanes streams a 1/32 stride of the input row and maintains a k-element min-heap in shared memory. The heap root is cached in a register, and insertions are performed with deferred writes. Finally, a k-round shuffle merge combines the 32 local TopK results.
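To make sure I am reading that description correctly, here is a minimal CPU-side Python model of the algorithm as I understand it. It is only a sketch of my interpretation, not code from the paper or the repository: the names `lane_local_topk` and `warp_topk` are hypothetical, the deferred shared-memory writes are not modelled, and the warp-shuffle reduction is replaced by a plain `max` over the 32 lanes.

```python
import heapq

WARP_SIZE = 32  # lanes per warp, as stated in the paper excerpt


def lane_local_topk(row, lane, k):
    """Lane `lane` scans row[lane::WARP_SIZE] and keeps its k largest (value, index) pairs."""
    heap = []  # min-heap; heap[0] plays the role of the register-cached root
    for idx in range(lane, len(row), WARP_SIZE):
        item = (row[idx], idx)
        if len(heap) < k:
            heapq.heappush(heap, item)
        elif item > heap[0]:
            # Only elements that beat the current root touch the heap.
            heapq.heapreplace(heap, item)
    return sorted(heap, reverse=True)  # best candidate first


def warp_topk(row, k):
    """Merge the 32 lane-local results in k rounds, one winner per round."""
    if k <= 0:
        return []
    lanes = [lane_local_topk(row, lane, k) for lane in range(WARP_SIZE)]
    cursors = [0] * WARP_SIZE  # next unconsumed candidate in each lane
    out = []
    for _ in range(min(k, len(row))):
        # Every lane offers its current best; the warp-wide maximum wins.
        # (Stand-in for a shuffle-based reduction on the GPU.)
        live = [l for l in range(WARP_SIZE) if cursors[l] < len(lanes[l])]
        best = max(live, key=lambda l: lanes[l][cursors[l]])
        out.append(lanes[best][cursors[best]])
        cursors[best] += 1
    return out
```

In this model `warp_topk(row, k)` returns the k largest `(value, index)` pairs in descending order, with ties broken by the larger index. The tie-breaking rule is my own choice for determinism, since the paper excerpt does not specify one.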
However, in the current codebase I could not find this heap-based Top-k implementation.
The exposed API:
`from fmha_sm100 import sparse_topk_select`
appears to call:
- `python/fmha_sm100/api.py::sparse_topk_select`
- `python/fmha_sm100/csrc/sparse_topk_select.cu`
- `python/fmha_sm100/csrc/include/sparse_topk_select.cuh`
The implementation in sparse_topk_select.cuh seems to be based on TensorRT-LLM indexerTopK, using histogram/threshold selection plus insertion sort, rather than the per-lane min-heap + shuffle-merge algorithm described in the paper.
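For contrast, this is roughly what I mean by "histogram/threshold selection plus insertion sort". The sketch below is a generic illustration of that family of algorithms in Python, not a transcription of `sparse_topk_select.cuh` or of TensorRT-LLM: the function name `histogram_topk`, the linear value binning, and the default of 256 bins are all assumptions made for the example.

```python
def histogram_topk(row, k, num_bins=256):
    """Generic histogram/threshold top-k followed by an insertion sort.

    Returns the k largest (value, index) pairs in descending order.
    """
    if k <= 0 or not row:
        return []
    lo, hi = min(row), max(row)
    scale = (num_bins - 1) / (hi - lo) if hi > lo else 0.0

    def bin_of(v):
        # Monotone mapping from value to bin: larger values never get a smaller bin.
        return int((v - lo) * scale)

    # Pass 1: count how many elements fall into each bin.
    hist = [0] * num_bins
    for v in row:
        hist[bin_of(v)] += 1

    # Walk bins from the top until at least k elements are covered.
    covered, threshold_bin = 0, 0
    for b in range(num_bins - 1, -1, -1):
        covered += hist[b]
        threshold_bin = b
        if covered >= k:
            break

    # Pass 2: keep everything at or above the threshold bin,
    # inserting each candidate into its sorted position.
    candidates = []
    for i, v in enumerate(row):
        if bin_of(v) >= threshold_bin:
            item = (v, i)
            pos = len(candidates)
            candidates.append(item)
            while pos > 0 and candidates[pos - 1] < item:
                candidates[pos] = candidates[pos - 1]
                pos -= 1
            candidates[pos] = item
    return candidates[:k]
```

The structural difference from the paper's description is that selection here is driven by a global threshold derived from bin counts, with no per-lane heap and no k-round merge.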
Could you clarify:
- Is the heap-based Top-k kernel from the paper included in this repository?
- If yes, where is the implementation located?
- If no, is the current sparse_topk_select implementation intended to replace the paper-described heap-based kernel?
- Are the benchmark numbers in the paper based on the heap-based kernel or the currently released sparse_topk_select kernel?
Thanks!