Commit db4ffe0

Release of MoE related code.

Co-authored-by: Shulai Zhang <[email protected]>
Co-authored-by: Ningxin Zheng <[email protected]>
Co-authored-by: Chengquan Jiang <[email protected]>
Co-authored-by: Wenlei Bao <[email protected]>
Co-authored-by: Qi Hou <[email protected]>
Co-authored-by: Ziheng Jiang <[email protected]>
Co-authored-by: Xin Liu <[email protected]>
Co-authored-by: Liwen Chang <[email protected]>
Co-authored-by: Haibin Lin <[email protected]>

1 parent 7f98c8a commit db4ffe0

File tree: 354 files changed, +62845 / -9804 lines changed


.clang-format

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+---
+Language: Cpp
+BasedOnStyle: Google
+IndentWidth: 2
+TabWidth: 2
+ColumnLimit: 99
+ContinuationIndentWidth: 4
+AccessModifierOffset: -1 # The private/protected/public has no indent in class
+Standard: c++17
+AllowShortBlocksOnASingleLine: false
+AllowShortCaseLabelsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: true
+AllowShortIfStatementsOnASingleLine: false
+AllowShortLoopsOnASingleLine: false
+AllowAllParametersOfDeclarationOnNextLine: true
+BinPackParameters: false
+BinPackArguments: false
+AlignAfterOpenBracket: AlwaysBreak
+AlwaysBreakTemplateDeclarations: true
+AlwaysBreakAfterDefinitionReturnType: All
+DerivePointerAlignment: false
+PointerAlignment: Right
+
+# clang-format 3.9+
+SortIncludes: false
+ReflowComments: true
+...
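For reference, the new style can be applied locally with clang-format picking this file up via `--style=file`; the sketch below assumes clang-format is installed and that C++/CUDA sources live under `src/` and `include/` (paths are illustrative, not dictated by this commit).

```bash
# Format C++/CUDA sources in place using the repository's .clang-format;
# --style=file makes clang-format search upward from each file for this config.
find src include \( -name '*.cc' -o -name '*.h' -o -name '*.cu' \) -print0 \
  | xargs -0 clang-format -i --style=file
```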

.gitignore

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+# PyCache files
+build/
+.cache/
+tmp/
+report*.sqlite
+report*.nsys-rep
+
+# run files
+log/
+prof/
+workspace/
+
+# general things to ignore
+dist/
+*.egg-info/
+.eggs/
+*.egg
+*.py[cod]
+__pycache__/
+*.so
+*.so.*
+*~
+python/flux/version.py
+
+# due to using tox and pytest and clangd
+.tox
+
+# 3rdparty
+/3rdparty/nvshmem/

.gitmodules

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
-[submodule "3rdparty/cutlass"]
-  path = 3rdparty/cutlass
-  url = https://github.com/NVIDIA/cutlass
 [submodule "3rdparty/nccl"]
   path = 3rdparty/nccl
   url = https://github.com/NVIDIA/nccl
+[submodule "3rdparty/cutlass"]
+  path = 3rdparty/cutlass
+  url = https://github.com/NVIDIA/cutlass
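With this reordering, NCCL and CUTLASS remain ordinary submodules, while NVSHMEM (ignored via `.gitignore` above) is supplied separately. A standard checkout, matching the commands in the updated README below, looks like:

```bash
# Clone with submodules, or initialize them inside an existing checkout.
# Note that 3rdparty/nvshmem is not a submodule and must be provided
# separately (see install_deps.sh and the README's dependency notes).
git clone --recursive https://github.com/bytedance/flux.git && cd flux
# or, in an existing clone:
git submodule update --init --recursive
```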

3rdparty/cutlass3.7.patch

Lines changed: 2089 additions & 0 deletions

Large diff not rendered.

CMakeLists.txt

Lines changed: 15 additions & 1 deletion

@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
-project(FLUX LANGUAGES CXX CUDA)
+project(FLUX LANGUAGES C CXX CUDA)
 
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/modules/")
 
@@ -15,6 +15,8 @@ message("PYTHONPATH: ${PYTHONPATH}")
 message("NVShmem Support: ${ENABLE_NVSHMEM}")
 
 # find cuda
+# specify cuda path if other than default
+# set(CUDA_TOOLKIT_ROOT_DIR /path/to/installed/cuda)
 find_package(CUDAToolkit REQUIRED)
 
 message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
@@ -102,6 +104,10 @@ print(os.path.dirname(torch.__file__),end='');"
 find_package(Torch REQUIRED)
 find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_DIR}/lib")
 
+if(TORCH_CXX_FLAGS)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+endif()
+
 execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "from __future__ import print_function; from distutils import sysconfig;
 print(sysconfig.get_python_inc());
 print(sysconfig.get_config_var('EXT_SUFFIX'));"
@@ -172,4 +178,12 @@ link_directories(
   ${COMMON_LIB_DIRS}
 )
 
+if (WITH_PROTOBUF)
+  FIND_PACKAGE(Protobuf REQUIRED)
+  add_subdirectory(proto)
+endif()
+
 add_subdirectory(src)
+if (BUILD_TEST)
+  add_subdirectory(test)
+endif()
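The new `WITH_PROTOBUF` and `BUILD_TEST` guards, the optional `CUDA_TOOLKIT_ROOT_DIR` hint, and the existing `ENABLE_NVSHMEM` message suggest a direct configure step roughly like the sketch below. `build.sh` is the supported entry point, so treat these cache variables and their values as illustrative assumptions rather than the documented invocation.

```bash
# Hypothetical direct CMake configure exercising the options seen in this diff.
cmake -S . -B build \
  -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda \
  -DENABLE_NVSHMEM=ON \
  -DWITH_PROTOBUF=OFF \
  -DBUILD_TEST=ON
cmake --build build -j
```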

MANIFEST.in

Lines changed: 3 additions & 5 deletions

@@ -1,5 +1,3 @@
-include src/ths_op/*.cc.inc
-exclude pynvshmem/
-recursive-include src *
-recursive-include include *
-recursive-include python/flux_ths_pybind
+global-exclude *.so*
+recursive-include python/flux/include *
+recursive-include python/flux/share *
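A quick way to sanity-check these packaging rules is to build a source distribution and inspect it; the `build` package used here is standard PyPA tooling, not something added by this commit.

```bash
# Build an sdist and list its contents to confirm that python/flux/include
# and python/flux/share are included while compiled *.so* files are excluded.
pip install build
python -m build --sdist
tar tzf dist/*.tar.gz
```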

NOTICE

Lines changed: 48 additions & 48 deletions

Large diff not rendered.

README.md

Lines changed: 77 additions & 55 deletions

@@ -1,85 +1,96 @@
 # Flux
 
-Flux is a fast communication-overlapping library for tensor parallelism on GPUs.
+Flux is a communication-overlapping library for dense/MoE models on GPUs, providing high-performance and pluggable kernels to support various parallelisms in model training/inference.
 
+Flux's efficient kernels are compatible with PyTorch and can be integrated into existing frameworks easily, supporting various Nvidia GPU architectures and data types.
 
-## Why Flux
+You are welcome to join the [WeChat](https://github.com/bytedance/flux/blob/main/docs/assets/comet_wechat_group.JPG) group and stay tuned!
 
-Flux can significantly reduce latency and increase throughput for tensor parallelism for both inference and training.
+## Getting started
+Install Flux either from source or from PyPI.
 
-## Install from pip
-```
-# Make sure that PyTorch is installed.
-pip install packaging
-pip install byte-flux
+### Install from Source
+```bash
+git clone --recursive https://github.com/bytedance/flux.git && cd flux
+
+# Install dependencies
+bash ./install_deps.sh
+
+# For Ampere(sm80) GPU
+./build.sh --arch 80 --nvshmem
+# For Ada Lovelace(sm89) GPU
+./build.sh --arch 89 --nvshmem
+# For Hopper(sm90) GPU
+./build.sh --arch 90 --nvshmem
 ```
 
-## Build from source
+#### Install in a virtual environment
+Here is a snippet to install Flux in a virtual environment with CUDA 12.4, torch 2.6.0, and Python 3.11.
+
 ```bash
-git clone https://github.com/bytedance/flux.git
-git submodule update --init --recursive
-# Ampere
-./build.sh --arch 80
-# Hopper
-./build.sh --arch 90
+conda create -n flux python=3.11
+conda activate flux
+pip3 install packaging
+pip3 install ninja
+pip3 install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+
+./build.sh --clean-all
+./build.sh --arch "80;89;90" --nvshmem --package
 ```
-## Build for cross-machine TP
-FLUX relies on NVSHMEM for communication across nodes. Therefore, if you need support for cross-machine tensor parallelism (TP), you must manually download the NVSHMEM source code and enable the nvshmem option during compilation.
+
+You should then find a wheel package under the `dist/` folder that matches your virtual environment.
+
+### Install from PyPI
+We also provide pre-built wheels for Flux, which can be installed directly with pip if the desired version is available. Currently we provide wheels for the following configurations: torch (2.4.0, 2.5.0, 2.6.0), python (3.10, 3.11), cuda (12.4).
 
 ```bash
-git clone https://github.com/bytedance/flux.git
-# Download nvshmem-2.11(https://developer.nvidia.com/nvshmem) and place it to flux/3rdparty/nvshmem
-# Flux is temporarily dependent on a specific version of nvshmem (2.11).
-tar Jxvf nvshmem_src_2.11.0-5.txz
-mv nvshmem_src_2.11.0-5 ${YOUR_PATH}/flux/3rdparty/nvshmem
-git submodule update --init --recursive
-
-# Ampere
-./build.sh --arch 80 --nvshmem
-# Hopper
-./build.sh --arch 90 --nvshmem
+# Make sure that PyTorch is installed.
+pip install byte-flux
 ```
 
-If you are tired of the cmake process, you can set environment variable `FLUX_BUILD_SKIP_CMAKE` to 1 to skip cmake if `build/CMakeCache.txt` already exists.
+### Customized Installation
+#### Build options for source installation
 
-If you want to build a wheel package, add `--package` to the build command. find the output wheel file under dist/
+1. Add `--nvshmem` to build Flux with NVSHMEM support. It is essential for the MoE kernels.
+2. If you are tired of the cmake process, you can set environment variable `FLUX_BUILD_SKIP_CMAKE` to 1 to skip cmake if `build/CMakeCache.txt` already exists.
+3. If you want to build a wheel package, add `--package` to the build command; find the output wheel file under `dist/`.
 
-```bash
-# Ampere
-./build.sh --arch 80 --package
 
-# Hopper
-./build.sh --arch 90 --package
-```
+#### Dependencies
+The core dependencies of Flux are NCCL, CUTLASS, and NVSHMEM, which are located under the 3rdparty folder.
+1. NCCL: Managed by git submodule automatically.
+2. NVSHMEM: Downloaded from https://developer.nvidia.com/nvshmem. The current version is 3.2.5-1.
+3. CUTLASS: Flux leverages CUTLASS to generate high-performance GEMM kernels. We currently use CUTLASS 3.7.0, with a small patch applied to it.
+
 
+## Quick Start
 
-## Run Demo
+Below are commands to run some basic demos once you have installed Flux successfully.
 ```bash
 # gemm only
-PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 --dtype=float16
+python3 test/python/gemm_only/test_gemm_only.py 4096 12288 6144 --dtype=float16
 
-# gemm fused with reduce-scatter
-./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
+# all-gather fused with gemm (dense MLP layer0)
+./launch.sh test/python/ag_gemm/test_ag_kernel.py 4096 49152 12288 --dtype=float16 --iters=10
 
-# all-gather fused with gemm
-./scripts/launch.sh test/test_ag_kernel.py 4096 49152 12288 --dtype=float16 --iters=10
-```
+# gemm fused with reduce-scatter (dense MLP layer1)
+./launch.sh test/python/gemm_rs/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
 
-## Performance
-We measured the examples from the above demo on both A800s and H800s. Each machine has 8 GPUs, with a TP size set to 8. The table below shows the performance comparison between flux and torch+nccl. It can be observed that by overlapping fine-grained computation and communication, Flux is able to effectively hide a significant portion of the communication time
+# all-gather fused with grouped gemm (MoE MLP layer0)
+./launch.sh test/python/moe_ag_scatter/test_moe_ag.py
 
-| | M | K | N | Torch Gemm | Torch NCCL | Torch Total | Flux Gemm | Flux Comm | Flux Total |
-|----------|----------|----------|----------|----------|----------|----------|----------|----------|-----------|
-| AG+Gemm(A800) | 4096 | 12288 | 49152 | 2.438ms | 0.662ms | 3.099ms | 2.378ms | 0.091ms | 2.469ms |
-| Gemm+RS(A800) | 4096 | 49152 | 12288 | 2.453ms | 0.646ms | 3.100ms | 2.429ms | 0.080ms | 2.508ms |
-| AG+Gemm(H800) | 4096 | 12288 | 49152 | 0.846ms | 0.583ms | 1.429ms | 0.814ms | 0.143ms | 0.957ms |
-| Gemm+RS(H800) | 4096 | 49152 | 12288 | 0.818ms | 0.590ms | 1.408ms | 0.822ms | 0.111ms | 0.932ms |
+# grouped gemm fused with reduce-scatter (MoE MLP layer1)
+./launch.sh test/python/moe_gather_rs/test_moe_gather_rs.py
+```
+
+Check out the documentation for more details!
 
-AG refers to AllGather.
-RS refers to ReduceScatter.
+* For more detailed usage of the MoE kernels, please refer to [Flux MoE Usage](https://github.com/bytedance/flux/blob/main/docs/moe_usage.md). Try some [examples](https://github.com/bytedance/flux/blob/main/examples) as a quick start. A [minimal MoE layer](https://github.com/bytedance/flux/blob/main/examples/moe_flux_only.py) can be implemented within only a few tens of lines of code using Flux!
+* For some performance numbers, please refer to [Performance Doc](https://github.com/bytedance/flux/blob/main/docs/performance.md).
+* To learn more about the design principles of Flux, please refer to [Design Doc](https://github.com/bytedance/flux/blob/main/docs/design.md).
 
 
-## Citing
+## Citations
 
 If you use Flux in a scientific publication, we encourage you to add the following reference
 to the related papers:
@@ -92,11 +103,22 @@ to the related papers:
   archivePrefix={arXiv},
   primaryClass={cs.LG}
 }
+
+@misc{zhang2025comet,
+  title={Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts},
+  author={Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen and Xin Liu},
+  year={2025},
+  eprint={2502.19811},
+  archivePrefix={arXiv},
+  primaryClass={cs.DC}
+}
+
 ```
 
 ## Reference
 
-* [ArXiv Paper](http://arxiv.org/abs/2406.06858)
+* [ArXiv Paper (Flux)](http://arxiv.org/abs/2406.06858)
+* [ArXiv Paper (Comet)](https://arxiv.org/abs/2502.19811)
 
 ## [License](./LICENSE)
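To close the loop on the packaging changes above, here is a hedged sketch of consuming the built wheel; the wheel filename pattern is assumed from the `byte-flux` PyPI name and will vary with your Python/torch/CUDA combination.

```bash
# Install the wheel produced by `./build.sh ... --package` and run one of
# the Quick Start demos; adjust the glob to the actual filename under dist/.
pip install dist/byte_flux-*.whl
python3 test/python/gemm_only/test_gemm_only.py 4096 12288 6144 --dtype=float16
```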
