Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.
BlaGPT is a flexible Transformer implementation that lets you turn the following features on and off in its config (a sketch of such a config follows the list below).
Multi-token prediction - link
Weight tying - link
Grouped query attention - link
Capping logits - link
QKV bias - link
Zero-init projection layer - link
Post and pre-RMSNorm - link
Setting base theta to 1_000_000 (Llama 3 style) - increased the final validation loss - best: 3.3324
Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527
KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB
Dilated Attention (LongNet) - link
Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB
Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB
DyT Norm - link - didn't really work; the loss got stuck too high
Forgetting Transformer (Vanilla and Pro versions) - link - vanilla loss: 3.3243, Pro loss: OOM
Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB
Differential Attention - link - loss: 3.3352 - peak memory: 41521 MiB
Softpick - link - loss: 3.3446 - peak memory: 59417 MiB
Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB
Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB
Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB
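For illustration, here is a minimal sketch of what toggling these features in a config could look like. The field names are assumptions chosen for readability, not the exact attribute names used in BlaGPT's config:

```python
from dataclasses import dataclass

# Illustrative sketch only: these field names are assumptions and may not
# match BlaGPT's actual config attributes; it just shows the toggle pattern.
@dataclass
class BlaGPTConfigSketch:
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    tie_weights: bool = True        # weight tying between embedding and lm_head
    use_gqa: bool = False           # grouped query attention
    n_kv_head: int = 4              # number of KV heads when GQA is enabled
    qkv_bias: bool = False
    logit_cap: float = 30.0         # soft-cap logits; <= 0 disables capping
    zero_init_proj: bool = True     # zero-init the output projection layer
    pre_post_rmsnorm: bool = True   # RMSNorm before and after each block
    rope_theta: float = 10_000.0    # 1_000_000 hurt validation loss above
    z_loss_coef: float = 0.0        # > 0 enables z-loss regularization
    kv_shifting: bool = False       # KV-shifting attention
    n_predict_tokens: int = 1       # > 1 enables multi-token prediction


# Example: enable GQA and KV-shifting, everything else at defaults.
cfg = BlaGPTConfigSketch(use_gqa=True, n_kv_head=4, kv_shifting=True)
```

The entries that follow are standalone architectures benchmarked alongside BlaGPT rather than config toggles.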
MegaByte - link - loss: 3.810
FTP (heavily modified) - link - loss: 3.901
Rene - link - loss: 3.340
Rwkv7 - link - loss: 4.450
Zamba2 - link - Zamba2 > Rene > Rwkv7
Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710
Hymba - link - train step time is significantly slower than the Transformer models. Best validation loss so far: 4.7505
Tokenformer (in BlaGPT model) - link - loss: 3.390
LLaDa (dLLM) - link - val-loss: 8.6930, xentropy-loss: 4.2891 (comparable to other models; estimated by llada_validation_cross_entropy.py)
Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)
Optimizers tried so far:
PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results
Ademamix - link - Unstable even after trying different learning rates.
Adopt - link - straight up NaN
CAdamW - link - loss: 3.3517
AdamW with independent weight decay - link - loss: 3.320
Adam - loss: 3.3224
AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms
DeMo - link - saves 7 GB per GPU, but loss is higher than baseline and step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms
Adam-Mini - link - loss is higher than Adam and AdamW and it is also slower (?), though it saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms
MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms
Muon - link - loss: 3.2923, peak VRAM: 40332 MiB, step_time: 620.24ms
BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms
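All of the optimizers above were compared under the same training setup. As a rough sketch of how one of them would be constructed (the build_optimizer helper below is hypothetical and not the repo's actual wiring; only standard torch optimizers are shown):

```python
import torch

def build_optimizer(model, name: str = "adamw", lr: float = 3e-4):
    # Hypothetical helper for illustration; BlaGPT's real optimizer setup may
    # differ. Third-party optimizers (SOAP, Muon, MARS, ...) would be imported
    # from their own packages and constructed the same way.
    params = [p for p in model.parameters() if p.requires_grad]
    if name == "adamw":
        # decoupled weight decay, matching the AdamW baseline above
        return torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
    if name == "adam":
        return torch.optim.Adam(params, lr=lr, betas=(0.9, 0.95))
    raise ValueError(f"unknown optimizer: {name}")
```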
To add a new model:
- Implement the model
- Return the loss in the forward function
- Add the model to model_registry.py
- And start training
See one of the implementations for details.
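A minimal sketch of these steps is below; the toy model and the registration line are illustrative assumptions, so check model_registry.py and the existing models for the real pattern:

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    # Toy model for illustration only; a real entry would be a full Transformer.
    def __init__(self, vocab_size: int = 50304, n_embd: int = 256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        logits = self.lm_head(self.embed(idx))
        # the training loop expects the loss to be returned from forward()
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

# Hypothetical registration; the exact mechanism lives in model_registry.py.
# MODEL_REGISTRY["tiny_lm"] = TinyLM
```

Once registered, the model is selected by passing its registry name via --model_name in the training command below.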
- Get the data by running data/fineweb10B_cached.py
- Start training with:
torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt
- (Optional) Run the learning rate finder before training:
torchrun --standalone --nproc_per_node=8 find_lr.py --model_name blagpt
# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
- Check best_model_config.py for the best model configuration so far.
- Run training with the best model config:
torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best
The initial code is based on NanoGPT - link - and Modded NanoGPT - link.
Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention.