Optimize low latency combine recv kernel (about 3.0x speedup) #248

Open · wants to merge 120 commits into main

Conversation


@fzyzcjy commented Jun 23, 2025

EDIT: The code is ready and only needs cleanup. Since @LyricZhao has been quite busy recently, I will do the code cleanup when he has time to merge PRs.


For reviewers: please do not merge this PR directly, since it contains too many unrelated changes. Just ping me and I will split it into the proper pieces.


Note that the code is pretty ugly, messy, and hacky. I will clean up the code later if you think the PR looks acceptable. The real (not yet cleaned) code is only in internode_ll.cu; this PR also contains a lot of unrelated code from other PRs.

WARN: I may have gotten something wrong since I am pretty tired right now. I will recheck everything later when I have some time, and also do further optimizations.

I personally care about the combine-recv kernel separately from the combine-send kernel, because I try to overlap the latter with the GEMM, while the former (for simplicity) may be executed directly.
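
To make the overlap intent concrete, here is a minimal self-contained host-side sketch (not DeepEP's actual API; the three stub kernels, stream setup, and sizes are purely illustrative): combine-send is issued on a communication stream so the GEMM can run concurrently on the compute stream, and combine-recv is launched by itself once the send has been ordered.

```cuda
#include <cuda_runtime.h>

// Trivial stand-in kernels; the real send/recv kernels live in internode_ll.cu.
__global__ void combine_send_stub(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] += 1.0f;   // pretend: push tokens to remote ranks
}
__global__ void gemm_stub(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] *= 2.0f;   // pretend: the GEMM being overlapped
}
__global__ void combine_recv_stub(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] -= 1.0f;   // pretend: reduce received tokens
}

int main() {
    const int n = 1 << 20;
    float *send_buf, *gemm_buf;
    cudaMalloc(&send_buf, n * sizeof(float));
    cudaMalloc(&gemm_buf, n * sizeof(float));

    cudaStream_t comm_stream, compute_stream;
    cudaStreamCreate(&comm_stream);
    cudaStreamCreate(&compute_stream);

    // combine-send and the GEMM overlap on different streams ...
    combine_send_stub<<<(n + 255) / 256, 256, 0, comm_stream>>>(send_buf, n);
    gemm_stub<<<(n + 255) / 256, 256, 0, compute_stream>>>(gemm_buf, n);

    // ... and combine-recv is executed directly once the send has been ordered.
    cudaEvent_t send_done;
    cudaEventCreate(&send_done);
    cudaEventRecord(send_done, comm_stream);
    cudaStreamWaitEvent(compute_stream, send_done, 0);
    combine_recv_stub<<<(n + 255) / 256, 256, 0, compute_stream>>>(send_buf, n);

    cudaDeviceSynchronize();
    cudaFree(send_buf);
    cudaFree(gemm_buf);
    cudaStreamDestroy(comm_stream);
    cudaStreamDestroy(compute_stream);
    return 0;
}
```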

before: 47-60 us (combine-recv time)

[rank 1] Dispatch + combine bandwidth: 509.39 GB/s, avg_t=261.84 us, min_t=257.63 us, max_t=267.97 us
[rank 3] Dispatch + combine bandwidth: 508.68 GB/s, avg_t=262.20 us, min_t=257.98 us, max_t=266.98 us
[rank 2] Dispatch + combine bandwidth: 510.61 GB/s, avg_t=261.21 us, min_t=256.29 us, max_t=267.87 us
[rank 0] Dispatch + combine bandwidth: 510.74 GB/s, avg_t=261.15 us, min_t=256.77 us, max_t=267.07 us
[rank 3] Dispatch bandwidth: 498.46 GB/s, avg_t=91.16 us | Combine bandwidth: 545.72 GB/s, avg_t=161.14 us
[rank 1] Dispatch bandwidth: 515.51 GB/s, avg_t=88.15 us | Combine bandwidth: 534.73 GB/s, avg_t=164.45 us
[rank 2] Dispatch bandwidth: 505.06 GB/s, avg_t=89.97 us | Combine bandwidth: 539.39 GB/s, avg_t=163.03 us
[rank 0] Dispatch bandwidth: 510.75 GB/s, avg_t=88.97 us | Combine bandwidth: 536.92 GB/s, avg_t=163.78 us
[rank 0] Dispatch send/recv time: 87.58 = 67.97 + 19.61 us | Combine send/recv time: 163.52 = 103.50 + 60.02 us
[rank 1] Dispatch send/recv time: 86.07 = 67.26 + 18.81 us | Combine send/recv time: 153.79 = 100.97 + 52.82 us
[rank 3] Dispatch send/recv time: 89.06 = 69.99 + 19.07 us | Combine send/recv time: 147.44 = 99.63 + 47.81 us
[rank 2] Dispatch send/recv time: 87.23 = 69.46 + 17.77 us | Combine send/recv time: 160.71 = 101.15 + 59.56 us

(old) after: 25-30 us (combine-recv time)

[rank 2] Dispatch + combine bandwidth: 565.85 GB/s, avg_t=235.71 us, min_t=231.68 us, max_t=238.85 us
[rank 0] Dispatch + combine bandwidth: 566.71 GB/s, avg_t=235.35 us, min_t=231.04 us, max_t=242.14 us
[rank 1] Dispatch + combine bandwidth: 567.44 GB/s, avg_t=235.05 us, min_t=228.22 us, max_t=243.97 us
[rank 3] Dispatch + combine bandwidth: 565.72 GB/s, avg_t=235.77 us, min_t=231.33 us, max_t=238.72 us
[rank 0] Dispatch bandwidth: 510.00 GB/s, avg_t=89.10 us | Combine bandwidth: 633.10 GB/s, avg_t=138.90 us
[rank 2] Dispatch bandwidth: 508.28 GB/s, avg_t=89.40 us | Combine bandwidth: 634.00 GB/s, avg_t=138.70 us
[rank 3] Dispatch bandwidth: 503.92 GB/s, avg_t=90.17 us | Combine bandwidth: 639.68 GB/s, avg_t=137.47 us
[rank 1] Dispatch bandwidth: 508.36 GB/s, avg_t=89.39 us | Combine bandwidth: 637.05 GB/s, avg_t=138.04 us
[rank 1] Dispatch send/recv time: 88.20 = 69.39 + 18.81 us | Combine send/recv time: 126.42 = 101.29 + 25.13 us
[rank 2] Dispatch send/recv time: 87.61 = 69.31 + 18.29 us | Combine send/recv time: 132.13 = 101.60 + 30.52 us
[rank 3] Dispatch send/recv time: 88.29 = 69.07 + 19.23 us | Combine send/recv time: 125.68 = 99.55 + 26.13 us
[rank 0] Dispatch send/recv time: 86.82 = 67.39 + 19.43 us | Combine send/recv time: 133.65 = 102.91 + 30.73 us

(old) after: 21-26 us (combine-recv time)

[rank 0] Dispatch + combine bandwidth: 560.83 GB/s, avg_t=237.82 us, min_t=233.95 us, max_t=242.56 us
[rank 2] Dispatch + combine bandwidth: 560.74 GB/s, avg_t=237.86 us, min_t=232.93 us, max_t=241.82 us
[rank 1] Dispatch + combine bandwidth: 559.85 GB/s, avg_t=238.24 us, min_t=232.26 us, max_t=242.37 us
[rank 3] Dispatch + combine bandwidth: 562.06 GB/s, avg_t=237.30 us, min_t=232.45 us, max_t=242.88 us
[rank 3] Dispatch bandwidth: 501.37 GB/s, avg_t=90.63 us | Combine bandwidth: 674.98 GB/s, avg_t=130.28 us
[rank 2] Dispatch bandwidth: 506.90 GB/s, avg_t=89.64 us | Combine bandwidth: 663.74 GB/s, avg_t=132.49 us
[rank 0] Dispatch bandwidth: 515.21 GB/s, avg_t=88.20 us | Combine bandwidth: 659.64 GB/s, avg_t=133.31 us
[rank 1] Dispatch bandwidth: 511.92 GB/s, avg_t=88.77 us | Combine bandwidth: 667.82 GB/s, avg_t=131.68 us
[rank 1] Dispatch send/recv time: 87.61 = 68.74 + 18.86 us | Combine send/recv time: 124.43 = 100.87 + 23.56 us
[rank 2] Dispatch send/recv time: 87.67 = 18.31 + 69.80 us | Combine send/recv time: 126.61 = 101.33 + 25.28 us
[rank 0] Dispatch send/recv time: 87.60 = 67.58 + 20.02 us | Combine send/recv time: 129.24 = 103.40 + 25.83 us
[rank 3] Dispatch send/recv time: 88.53 = 69.08 + 19.45 us | Combine send/recv time: 119.88 = 98.93 + 20.95 us

after: 18.9-19.4 us (combine-recv time)

[rank 2] Dispatch + combine bandwidth: 579.30 GB/s, avg_t=230.24 us, min_t=224.48 us, max_t=235.04 us
[rank 1] Dispatch + combine bandwidth: 577.42 GB/s, avg_t=230.99 us, min_t=225.92 us, max_t=238.27 us
[rank 0] Dispatch + combine bandwidth: 577.39 GB/s, avg_t=231.00 us, min_t=226.11 us, max_t=234.88 us
[rank 3] Dispatch + combine bandwidth: 579.84 GB/s, avg_t=230.02 us, min_t=224.70 us, max_t=237.82 us
[rank 0] Dispatch bandwidth: 516.30 GB/s, avg_t=88.01 us | Combine bandwidth: 697.62 GB/s, avg_t=126.05 us
[rank 2] Dispatch bandwidth: 508.33 GB/s, avg_t=89.39 us | Combine bandwidth: 703.44 GB/s, avg_t=125.01 us
[rank 1] Dispatch bandwidth: 512.04 GB/s, avg_t=88.74 us | Combine bandwidth: 703.73 GB/s, avg_t=124.96 us
[rank 3] Dispatch bandwidth: 505.60 GB/s, avg_t=89.88 us | Combine bandwidth: 702.96 GB/s, avg_t=125.10 us
[rank 0] Dispatch send/recv time: 85.99 = 66.52 + 19.47 us | Combine send/recv time: 121.71 = 102.72 + 19.00 us
[rank 1] Dispatch send/recv time: 87.27 = 68.54 + 18.73 us | Combine send/recv time: 120.28 = 101.04 + 19.24 us
[rank 2] Dispatch send/recv time: 87.15 = 68.95 + 18.20 us | Combine send/recv time: 119.64 = 100.71 + 18.94 us
[rank 3] Dispatch send/recv time: 87.94 = 68.78 + 19.17 us | Combine send/recv time: 119.07 = 99.69 + 19.37 us

after: 17.5-17.7 us (combine-recv time), 3.04x speedup

[rank 1] Dispatch + combine bandwidth: 591.47 GB/s, avg_t=225.50 us, min_t=218.24 us, max_t=230.11 us
[rank 0] Dispatch + combine bandwidth: 591.01 GB/s, avg_t=225.68 us, min_t=219.97 us, max_t=230.59 us
[rank 2] Dispatch + combine bandwidth: 589.63 GB/s, avg_t=226.20 us, min_t=219.04 us, max_t=233.57 us
[rank 3] Dispatch + combine bandwidth: 588.74 GB/s, avg_t=226.55 us, min_t=221.34 us, max_t=232.29 us
[rank 3] Dispatch bandwidth: 501.61 GB/s, avg_t=90.59 us | Combine bandwidth: 739.42 GB/s, avg_t=118.93 us
[rank 0] Dispatch bandwidth: 502.56 GB/s, avg_t=90.42 us | Combine bandwidth: 741.52 GB/s, avg_t=118.59 us
[rank 2] Dispatch bandwidth: 510.22 GB/s, avg_t=89.06 us | Combine bandwidth: 727.57 GB/s, avg_t=120.86 us
[rank 1] Dispatch bandwidth: 516.93 GB/s, avg_t=87.91 us | Combine bandwidth: 723.02 GB/s, avg_t=121.62 us
[rank 1] Dispatch send/recv time: 86.35 = 67.66 + 18.69 us | Combine send/recv time: 118.72 = 101.16 + 17.56 us
[rank 0] Dispatch send/recv time: 88.51 = 68.58 + 19.92 us | Combine send/recv time: 118.35 = 100.85 + 17.50 us
[rank 2] Dispatch send/recv time: 87.44 = 68.79 + 18.66 us | Combine send/recv time: 118.51 = 100.82 + 17.68 us
[rank 3] Dispatch send/recv time: 88.14 = 69.25 + 18.89 us | Combine send/recv time: 117.99 = 100.31 + 17.68 us


@fzyzcjy commented Jul 5, 2025

Brainstorm: maybe we can carry information during dispatch, such that when doing combine we have a buffer of shape (num_tokens, num_topk, hidden) and send data directly into it. Then we would not need to read topk_idx, which saves time and also a bit of memory.
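
A rough CUDA sketch of the proposed layout (hypothetical kernel and buffer names, not this PR's code; top-k weights are omitted for brevity): if dispatch already carried the (token, k) coordinate, combine could write each contribution straight into recv_buf[token][k][:], and the recv side would just sum contiguously over the top-k slots without ever reading topk_idx.

```cuda
// Illustration only: one block per token, contiguous reduction over top-k slots.
__global__ void combine_recv_direct(const float* __restrict__ recv_buf, // [num_tokens, num_topk, hidden]
                                    float* __restrict__ out,            // [num_tokens, hidden]
                                    int num_topk, int hidden) {
    const int token = blockIdx.x;
    for (int h = threadIdx.x; h < hidden; h += blockDim.x) {
        float acc = 0.0f;
        // Senders already placed their data at slot (token, k), so this is a
        // plain strided sum instead of a gather through topk_idx.
        for (int k = 0; k < num_topk; ++k)
            acc += recv_buf[(token * num_topk + k) * hidden + h];
        out[token * hidden + h] = acc;
    }
}
```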

I am sharing this first instead of implementing it because I want to know whether it conflicts with any ongoing changes.

EDIT: By the way, another minor update is to fully use 1024 threads so that the number of waves is reduced by one.
