Optimize low latency combine recv kernel (about 3.0x speedup) #248
(cherry picked from commit df72cff)

# Conflicts:
#	tests/test_intranode.py
#	tests/utils.py
Brainstorm: maybe we can carry extra information during dispatch so that, when doing combine, we have a buffer of shape (num_tokens, num_topk, hidden) and send data directly into it. Then combine no longer needs to read topk_idx, which saves time and also a bit of memory. I am sharing this first instead of implementing it because I want to know whether it conflicts with any ongoing changes.
EDIT: By the way, another minor update is to fully use 1024 threads so that the number of waves is reduced by one.
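To make the brainstormed layout concrete, here is a minimal host-side reference sketch in C++, not the actual kernel. It assumes combine is a weighted sum over the top-k slots of each token, that the receive buffer is row-major with shape (num_tokens, num_topk, hidden), and that unused slots carry a weight of zero. The function name `combine_recv_reference` and the `topk_weights` argument are hypothetical names for illustration; the comment above only mentions topk_idx.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference of a combine-recv reduction over a slot-indexed buffer.
// Assumptions (not taken from the PR):
//   - recv_buf is row-major with shape (num_tokens, num_topk, hidden)
//   - topk_weights has shape (num_tokens, num_topk)
//   - unused slots have weight 0, so they contribute nothing
std::vector<float> combine_recv_reference(const std::vector<float>& recv_buf,
                                          const std::vector<float>& topk_weights,
                                          std::size_t num_tokens,
                                          std::size_t num_topk,
                                          std::size_t hidden) {
    assert(recv_buf.size() == num_tokens * num_topk * hidden);
    assert(topk_weights.size() == num_tokens * num_topk);

    std::vector<float> out(num_tokens * hidden, 0.0f);
    for (std::size_t t = 0; t < num_tokens; ++t) {
        for (std::size_t k = 0; k < num_topk; ++k) {
            const float w = topk_weights[t * num_topk + k];
            // Slot (t, k) is addressed directly; no topk_idx lookup is needed
            // to find where the data for this slot was written.
            const float* src = recv_buf.data() + (t * num_topk + k) * hidden;
            for (std::size_t h = 0; h < hidden; ++h) {
                out[t * hidden + h] += w * src[h];
            }
        }
    }
    return out;
}
```

In this sketch the address of each contribution depends only on the token index and the slot index, which is the property that would let the receive side skip reading topk_idx.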
EDIT: The code is ready and only needs cleanup. Since LyricZhao has been very busy recently, I will do the cleanup once LyricZhao has time to merge PRs.
For reviewers: Please do not merge this PR directly, since it contains too many unrelated changes. Just ping me and I will split it into separate pieces.
Note that the code is pretty ugly, messy, and hacky. I will clean it up later if you think the PR looks acceptable. The real (not yet cleaned) changes are only in internode_ll.cu; this PR also contains a lot of unrelated code from other PRs.
WARN: I may have made mistakes, since I am pretty tired right now. I will recheck everything later when I have some time, and also do other optimizations.
I personally care about the combine-recv kernel separately from the combine-send kernel, because I try to overlap the latter with GEMM, while the former (for simplicity) may be executed directly.
before: 47-60us
(old) after: 25-30us
(old) after: 21-26us
after: 18.9-19.4us
after: 17.5-17.7us
3.04x speedup
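The PR does not say how the 3.04x figure was derived. One reading that is consistent with the reported numbers is the ratio of the range midpoints: the "before" range of 47-60us against the latest "after" range of 17.5-17.7us. The small C++ sketch below checks that arithmetic; `midpoint_speedup` is a hypothetical helper, not something from the PR.

```cpp
#include <cassert>

// Ratio of the midpoints of two latency ranges.
// All four arguments must be in the same unit (microseconds here).
double midpoint_speedup(double before_lo, double before_hi,
                        double after_lo, double after_hi) {
    assert(after_lo > 0.0 && after_hi > 0.0);
    const double before_mid = 0.5 * (before_lo + before_hi);  // 53.5 for 47-60
    const double after_mid = 0.5 * (after_lo + after_hi);     // 17.6 for 17.5-17.7
    return before_mid / after_mid;
}
```

Calling `midpoint_speedup(47.0, 60.0, 17.5, 17.7)` gives roughly 3.04, matching the reported figure under this assumption.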