This repository contains the implementation of ProxSparse, a learned method for semi-structured (2:4) pruning of Large Language Models using only hundreds of calibration samples. ProxSparse does not involve additional weight updates once the mask is determined. You can find our paper (ICML'25) here.
The required environment for ProxSparse is listed in `requirement.txt`. You can run the commands below to install it.
```bash
conda create --name proxsparse python==3.10
conda activate proxsparse
pip install -r requirement.txt
```
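Optionally, you can run a quick sanity check of the environment. This is a minimal sketch and assumes `torch` and `transformers` are among the packages installed from `requirement.txt`:

```python
# Minimal environment sanity check (illustrative; assumes torch and
# transformers are installed via requirement.txt).
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```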
We release our 2:4 pruned models induced by ProxSparse in the Hugging Face repository `aladinggit/proxsparse_models`. The repo contains 2:4 pruned checkpoints of Llama-2-7b, Llama-2-13b, Llama-3.1-8b, Mistral-v0.1-7b, Mistral-v0.3-7b, Openllama-v2-7b, and Qwen-2.5-14b.
The download script below fetches the pruned models from the Hugging Face repository.
```bash
python proxsparse_pruned_model_download.py
```
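If you prefer to fetch the checkpoints manually, the sketch below shows one way to do it with `huggingface_hub` and `transformers`. The subfolder name used here is a hypothetical placeholder; check the repository layout (or simply use the download script above) for the actual paths.

```python
# Illustrative sketch: download the released checkpoints and load one of them.
# The subfolder name "llama2-7b" is a hypothetical placeholder; inspect the
# aladinggit/proxsparse_models repository for the actual directory names.
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

local_dir = snapshot_download(repo_id="aladinggit/proxsparse_models")
model_path = f"{local_dir}/llama2-7b"  # hypothetical subfolder

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
```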
`eval/ppl.py` contains a standard evaluation of C4/WikiText perplexity. To run a quick WikiText perplexity evaluation on those checkpoints:
```bash
bash script/eval_pruned_prox_model.sh
```
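For reference, a generic WikiText-2 perplexity loop looks roughly like the sketch below. This is not the exact `eval/ppl.py` implementation, and the checkpoint path and context length are placeholders:

```python
# Generic WikiText-2 perplexity evaluation (illustrative; not eval/ppl.py).
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/pruned/checkpoint"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto").eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)

test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
enc = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

seqlen, nlls = 2048, []
n_chunks = enc.input_ids.shape[1] // seqlen
for i in range(n_chunks):
    ids = enc.input_ids[:, i * seqlen : (i + 1) * seqlen].to(model.device)
    with torch.no_grad():
        # With labels=ids, the model returns the mean next-token NLL over the chunk.
        nlls.append(model(ids, labels=ids).loss.float() * seqlen)
ppl = torch.exp(torch.stack(nlls).sum() / (n_chunks * seqlen))
print(f"WikiText-2 perplexity: {ppl.item():.2f}")
```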
Here we provide the steps to prune models with ProxSparse. We use the `transformers`, `accelerate`, and `trl` packages as the main training framework.
The core ProxSparse operator can be found in `end-to-end/prox_op.py`. If you want to integrate the ProxSparse operator into a different training framework, more information about the integration is provided in `end-to-end/doc.txt`.
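As a rough illustration of where such an operator typically hooks into a training loop, the sketch below applies a plain hard 2:4 magnitude projection to every linear layer after each optimizer step. The projection is only a stand-in: it is not the regularized ProxSparse operator, so refer to `end-to-end/prox_op.py` and `end-to-end/doc.txt` for the actual implementation and integration details.

```python
# Generic sketch of hooking a per-group-of-4 update into a PyTorch training
# loop. The hard 2:4 magnitude projection below is only a stand-in for the
# regularized ProxSparse operator in end-to-end/prox_op.py.
import torch

def project_2_of_4_(weight: torch.Tensor) -> None:
    """In place: keep only the 2 largest-magnitude entries in every group of 4."""
    w = weight.detach().view(-1, 4)
    # Zero out the 2 smallest-magnitude entries in each group of 4.
    _, idx = torch.topk(w.abs(), k=2, dim=1, largest=False)
    w.scatter_(1, idx, 0.0)

def training_step(model, batch, optimizer):
    # Assumes batch contains input_ids, attention_mask, and labels.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # One possible integration point: apply the operator to the weights of
    # every linear layer right after the optimizer step.
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, torch.nn.Linear) and module.weight.numel() % 4 == 0:
                project_2_of_4_(module.weight)
    return loss
```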
We provide scripts to run ProxSparse. The `script` directory contains scripts for learning masks. For example,

```bash
bash learn_mask_llama2_7b_prox.sh
```

will learn the mask for the Llama-2-7b model. To learn a mask for the Qwen-2.5-14b model, run

```bash
bash learn_mask_qwen_2.5_14b_prox.sh
```

The descriptions of the parameters can be found in the entry point `end-to-end/main.py`. We quickly go through the meaning of the arguments here:
The `*prox.sh` scripts launch training with only the $\lambda_{1}$ (semi-structured) regularizer:

- `--model_dir` and `model_subdir`: the name of the model.
- `--lambda`: the hyperparameter $\lambda_{1}$, denoting the strength of the semi-structured regularizer.
- `--batch_size`: data batch size.
- `--ctx_len`: the context length of the data used in the training process.
- `--samples`: the number of samples used in the training process.
- `--lr`: learning rate.
The `*full.sh` scripts launch training with both the $\lambda_{1}$ (semi-structured) regularizer and the $\lambda_{2}$ (frozen weight) regularizer. In addition to the arguments above, they take:

- `--lambda2_`: the hyperparameter $\lambda_{2}$, denoting the strength of the frozen weight regularizer.
- `--project_lambda2`: set this to 0 (the default) to train only with the $\lambda_{1}$ (semi-structured) regularizer; set this to 1 to train with both the $\lambda_{1}$ (semi-structured) and $\lambda_{2}$ (frozen weight) regularizers.
- `--epsilon`: the epsilon term in the $\lambda_{2}$ regularizer to avoid numerical instability.
Each script performs three main steps: learning with ProxSparse, extracting the mask, and applying and evaluating the mask. More information about the scripts can be found in `script/doc.txt`. The optimal configuration for each model is already set in the corresponding script and can be referred to in our paper. To extract masks for models not covered by these scripts, modify the model name and the other corresponding configurations accordingly.
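After applying a mask, it can be useful to verify that the 2:4 pattern actually holds in the resulting model. A small check along these lines (assuming the mask has already been applied to the linear-layer weights):

```python
# Check that every group of 4 consecutive weights contains at most 2 non-zeros.
import torch

def check_2_of_4(model) -> bool:
    ok = True
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and module.weight.numel() % 4 == 0:
            groups = module.weight.detach().view(-1, 4)
            nonzeros_per_group = (groups != 0).sum(dim=1)
            if (nonzeros_per_group > 2).any():
                print(f"{name}: 2:4 pattern violated")
                ok = False
    return ok
```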
The `baseline` directory contains the implementation from the Wanda repository. For the other baselines (AdmmPrune, OWL, and AlphaPrune), please refer to their original implementations for details.
This project welcomes contributions and suggestions; see CONTRIBUTING.md for details. This project has adopted the Amazon Open Source Code of Conduct. For more information, see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.
This project is licensed under the Apache 2.0 license.
If you find our work useful, please cite:
```bibtex
@article{liu2025proxsparse,
  title={ProxSparse: Regularized Learning of Semi-Structured Sparsity Masks for Pretrained LLMs},
  author={Liu, Hongyi and Saha, Rajarshi and Jia, Zhen and Park, Youngsuk and Huang, Jiaji and Sabach, Shoham and Wang, Yu-Xiang and Karypis, George},
  journal={arXiv preprint arXiv:2502.00258},
  year={2025}
}
```