Latest commit 735e32f · Jul 12, 2025 · 160 lines (88 loc) · 6.7 KB

BlaGPT

Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.

BlaGPT Model

BlaGPT is a flexible Transformer implementation in which each of the following features can be turned on or off in the config.

Multi-token prediction - link

Weight tying - link

Grouped query attention - link

Capping logits - link

QKV bias - link

Zero-init projection layer - link

Post and pre-RMSNorm - link

Setting the RoPE base theta to 1_000_000 (llama3) - increased the final validation loss - best: 3.3324

Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527

KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB

Dilated Attention (LongNet) - link

Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB

Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB

DyT Norm - link - did not work well; the loss plateaued too high

Forgetting Transformer (vanilla and Pro versions) - link - vanilla loss: 3.3243, Pro: out of memory (OOM)

Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB

Differential Attention - link - loss: 3.3352 - peak memory: 41521 MiB

Softpick - link - loss: 3.3446 - peak memory: 59417 MiB

Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB

Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB

Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB
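Two of the tricks above, logit capping and z-loss, are easy to state as formulas. The framework-free sketch below assumes that capping means tanh soft-capping (`cap * tanh(x / cap)`, as used in Gemma 2) and that z-loss means adding `z_coef * (log Z)^2` to the cross-entropy, where `Z` is the softmax normalizer (as in PaLM). The function names, the cap of 30, and `z_coef = 1e-4` are illustrative assumptions, not values read from BlaGPT's config.

```python
import math
from typing import List, Sequence, Tuple


def soft_cap(logits: Sequence[float], cap: float = 30.0) -> List[float]:
    """Assumed tanh soft-capping: smoothly squashes each logit toward (-cap, cap)."""
    return [cap * math.tanh(x / cap) for x in logits]


def log_sum_exp(logits: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(logits)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def cross_entropy_with_z_loss(
    logits: Sequence[float], target: int, z_coef: float = 1e-4
) -> Tuple[float, float, float]:
    """Per-token cross-entropy plus a z-loss penalty; returns (total, ce, z_loss)."""
    log_z = log_sum_exp(logits)
    ce = log_z - logits[target]      # -log softmax(logits)[target]
    z_loss = z_coef * log_z ** 2     # pushes log Z toward 0, i.e. Z toward 1
    return ce + z_loss, ce, z_loss


logits = soft_cap([12.0, -3.0, 0.5, 80.0], cap=30.0)
total, ce, z = cross_entropy_with_z_loss(logits, target=0)
print(f"capped={[round(x, 3) for x in logits]} ce={ce:.4f} z_loss={z:.6f}")
```

Soft-capping barely changes small logits but keeps very large ones from dominating the softmax; z-loss adds a small term even when the prediction is correct, because it only penalizes the normalizer drifting away from 1.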

Other Models

MegaByte - link - loss: 3.810

FTP (heavily modified) - link - loss: 3.901

Rene - link - loss: 3.340

Rwkv7 - link - loss: 4.450

Zamba2 - link - Zamba2 > Rene > Rwkv7

Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710

Hymba - link - training steps are significantly slower than with the Transformer models. Best validation loss so far: 4.7505

Tokenformer (in BlaGPT model) - link - loss: 3.390

LLaDA (dLLM) - link - val-loss: 8.6930, xentropy-loss: 4.2891 (estimated with llada_validation_cross_entropy.py; comparable to the other models' losses)

Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)

LFM2 - link - TBD

Byte-Level Models

Hourglass Transformer (modified) - link - val_loss:1.0048 train_time:2671049ms step_avg:524.76ms

AUNet - link - val_loss:1.1502 train_time:7246104ms step_avg:1423.60ms

SpaceByte - link - val_loss:1.6755 train_time:2154923ms step_avg:423.36ms peak memory consumption: 27781 MiB

HNet - link - val_loss:1.4554 train_time:2207809ms step_avg:433.75ms peak memory consumption: 23948 MiB
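To compare these byte-level numbers with results reported elsewhere, it can help to convert the losses to bits per byte and to check how many steps the timings imply. The sketch below assumes that `val_loss` is the mean cross-entropy in nats per byte and that `step_avg` is `train_time` divided by the number of timed steps; both are assumptions about the training script's logging, and the helper names are hypothetical.

```python
import math

# Byte-level results copied from the list above: (val_loss, train_time_ms, step_avg_ms).
RESULTS = {
    "Hourglass Transformer": (1.0048, 2671049, 524.76),
    "AUNet": (1.1502, 7246104, 1423.60),
    "SpaceByte": (1.6755, 2154923, 423.36),
    "HNet": (1.4554, 2207809, 433.75),
}


def bits_per_byte(val_loss_nats: float) -> float:
    # Assumption: val_loss is measured in nats per byte; divide by ln 2 for bits.
    return val_loss_nats / math.log(2)


def implied_timed_steps(train_time_ms: float, step_avg_ms: float) -> float:
    # Assumption: step_avg = train_time / number_of_timed_steps.
    return train_time_ms / step_avg_ms


for name, (loss, train_ms, step_ms) in RESULTS.items():
    print(f"{name:22s} bpb={bits_per_byte(loss):.3f} "
          f"steps~{implied_timed_steps(train_ms, step_ms):.0f} "
          f"train_time={train_ms / 3.6e6:.2f} h")
```

Under these assumptions, all four runs imply roughly the same number of timed steps (about 5,090), so the differences in `train_time` would come from per-step cost rather than run length.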

Optimizers

PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results

Ademamix - link - Unstable even after trying different learning rates.

Adopt - link - loss went straight to NaN

CAdamW - link - loss: 3.3517

AdamW with independent weight decay - link - loss: 3.320

Adam - loss: 3.3224

AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms

DeMo - link - saves 7 GB per GPU; loss is higher than the baseline and steps are slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms

Adam-Mini - link - loss is higher than both Adam and AdamW and, unexpectedly, steps are also slower; saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms

MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms

Muon - link - loss: 3.2923, peak VRAM: 40332 MiB, step_time: 620.24ms

BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms
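Several entries above compare AdamW variants. A minimal single-parameter sketch of one AdamW step shows where weight decay enters: standard AdamW (the behavior of `torch.optim.AdamW`) multiplies the decay by the learning rate, while the `independent_wd=True` branch here assumes that "independent" means the decay is not scaled by the learning rate. That reading is an assumption; the linked implementation may define it differently (for example, by scaling with the schedule factor rather than the raw learning rate). The function name and defaults are illustrative.

```python
import math


def adamw_step(p, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.0, independent_wd=False):
    """One AdamW update for a single scalar parameter; t is the 1-based step index."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter directly, not through the gradient.
    if independent_wd:
        p = p * (1.0 - weight_decay)          # assumed variant: not scaled by lr
    else:
        p = p * (1.0 - lr * weight_decay)     # torch.optim.AdamW behavior
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Bias correction for the zero-initialized moments.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# With zero gradient only the decay acts: both modes shrink p by the same amount
# when the independent coefficient equals lr * (coupled coefficient).
p_coupled, _, _ = adamw_step(1.0, grad=0.0, m=0.0, v=0.0, t=1, lr=1e-3, weight_decay=0.1)
p_indep, _, _ = adamw_step(1.0, grad=0.0, m=0.0, v=0.0, t=1, weight_decay=1e-4,
                           independent_wd=True)
print(p_coupled, p_indep)
```

The Adam part of the update is identical in both branches, so the two modes differ only in how `weight_decay` has to be tuned, and in whether changing the learning rate also changes the effective decay.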

Adding a New Model

  • Implement the model
  • Return the loss in the forward function
  • Add model to model_registry.py
  • Start training

See one of the implementations for details.
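The steps above can be sketched as follows. This is a framework-free illustration of the two requirements the list names, a forward pass that returns the loss and a registry lookup by name; the real models are PyTorch modules, and `model_registry.py` may be organized differently. `MODEL_REGISTRY`, `register_model`, and `UniformToyLM` are hypothetical names.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# Hypothetical registry mapping a model name string to a model class or factory.
MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str):
    """Decorator that adds a model class (or factory) to MODEL_REGISTRY."""
    def wrap(factory):
        if name in MODEL_REGISTRY:
            raise ValueError(f"model name already registered: {name}")
        MODEL_REGISTRY[name] = factory
        return factory
    return wrap


def cross_entropy(logits_row: Sequence[float], target: int) -> float:
    """-log softmax(logits_row)[target], computed stably."""
    m = max(logits_row)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_z - logits_row[target]


@register_model("uniform_toy")
class UniformToyLM:
    """Toy model that assigns equal logits to every token in the vocabulary."""

    def __init__(self, vocab_size: int = 16):
        self.vocab_size = vocab_size

    def forward(self, idx: Sequence[int], targets: Optional[Sequence[int]] = None
                ) -> Tuple[List[List[float]], Optional[float]]:
        logits = [[0.0] * self.vocab_size for _ in idx]
        loss = None
        if targets is not None:
            # The model computes and returns its own mean loss; in a PyTorch
            # version the training script could then call loss.backward().
            loss = sum(cross_entropy(row, t) for row, t in zip(logits, targets)) / len(targets)
        return logits, loss


model = MODEL_REGISTRY["uniform_toy"](vocab_size=8)
_, loss = model.forward([1, 2, 3], targets=[2, 3, 4])
print(f"loss={loss:.4f}")  # uniform over 8 tokens, so the loss is log(8)
```

With a registry like this, a name passed via `--model_name` only has to be looked up in one dictionary, so adding a model need not touch the training loop.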

Training

  • Get the data by running data/fineweb10B_cached.py

  • Start training with:

torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt

  • (Optional) Run the learning rate finder before training:

torchrun --standalone --nproc_per_node=8 find_lr.py --model_name blagpt

# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
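The finder reports a "steepest gradient" learning rate. One common definition is: sweep the learning rate exponentially, record (usually smoothed) losses, and pick the learning rate where the loss falls fastest with respect to log(lr). The sketch below assumes that definition; `find_lr.py` may smooth, trim, or select points differently, and its elbow-point rule is not reproduced here. All function names are hypothetical, and the losses in the example are made up for illustration.

```python
import math
from typing import List, Sequence


def lr_sweep(min_lr: float, max_lr: float, num_steps: int) -> List[float]:
    """Exponentially spaced learning rates from min_lr to max_lr (inclusive)."""
    ratio = (max_lr / min_lr) ** (1.0 / (num_steps - 1))
    return [min_lr * ratio ** i for i in range(num_steps)]


def smooth(losses: Sequence[float], beta: float = 0.98) -> List[float]:
    """Bias-corrected exponential moving average of the raw losses."""
    out, avg = [], 0.0
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1.0 - beta) * loss
        out.append(avg / (1.0 - beta ** i))
    return out


def steepest_gradient_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """LR at the end of the interval where loss drops fastest per unit of log(lr)."""
    best_i, best_slope = 1, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]


lrs = lr_sweep(1e-6, 1.0, 7)
losses = [10.0, 9.9, 9.5, 7.0, 6.5, 6.4, 20.0]  # made-up losses for illustration
print(f"steepest gradient lr: {steepest_gradient_lr(lrs, losses):.2e}")
```

On a real sweep with thousands of noisy points, the losses would normally be passed through `smooth` first; on raw losses the slope estimate is dominated by step-to-step noise.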

Best Model So Far

  • Check best_model_config.py for the best model configuration so far.

  • You can run the training with the best model config by running:

torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best

Acknowledgements

The initial code is based on

Nano GPT - link

Modded NanoGPT - link

Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention