Commit efecfc5

Guang Yang (guangy10) authored, with co-authors

Introduce 8da4w quant for decoder-only text models (#62)

* Introduce 8da4w quant for decoder-only text models
* rebase on gemma3 ci and log pte file size
* add support for embedding quantization

Co-authored-by: Guang Yang <[email protected]>

1 parent 2df9165 commit efecfc5

File tree: 10 files changed, +265 −17 lines changed

.github/workflows/test_models.yml

Lines changed: 3 additions & 1 deletion

@@ -50,13 +50,15 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies for ExecuTorch
+        # Consolidate torchao nightly version once https://github.com/pytorch/ao/issues/2157 is fixed
         run: |
           if [ "${{ matrix.executorch-version }}" == "nightly" ]; then
-            export NIGHTLY_VERSION=dev20250413
+            export NIGHTLY_VERSION=dev20250501
             pip install executorch==0.7.0.${NIGHTLY_VERSION} \
               torch==2.8.0.${NIGHTLY_VERSION} \
               torchvision==0.22.0.${NIGHTLY_VERSION} \
               torchaudio==2.6.0.${NIGHTLY_VERSION} \
+              torchao==0.11.0.dev20250422 \
               --extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
           else
             pip install executorch==${{ matrix.executorch-version }}

README.md

Lines changed: 13 additions & 5 deletions

@@ -129,7 +129,7 @@ generated_text = model.text_generation(
 print(generated_text)
 ```
 
-## Supported Models and Backend
+## Supported Models
 
 **Optimum-ExecuTorch** currently supports the following transformer models:
 
@@ -166,20 +166,28 @@ We currently support a wide range of popular transformer models, including encod
 - [Pvt](https://huggingface.co/Zetatech/pvt-tiny-224): Pyramid Vision Transformer (tiny-sized)
 - [Swin](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224): Swin Transformer (tiny-sized)
 
-🚀 More coming soon...
-
 ### Audio Models
 #### Encoder-decoder models
 - [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants
 
 *📌 Note: This list is continuously expanding. As we continue to expand support, more models will be added.*
 
-**Supported Backend:**
 
-Currently, **Optimum-ExecuTorch** supports only the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) for efficient CPU execution on mobile devices. Quantization support for XNNPACK is planned to be added shortly.
+## Supported Optimizations
+
+### Custom Operators
+Supported using [custom SDPA](https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40) with Hugging Face Transformers, boosting performance by 3x compared to default SDPA, based on tests with `HuggingFaceTB/SmolLM2-135M`.
+
+### Backends Delegation
+Currently, **Optimum-ExecuTorch** supports the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) with [custom SDPA](https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40) for efficient execution on mobile CPUs.
 
 For a comprehensive overview of all backends supported by ExecuTorch, please refer to the [ExecuTorch Backend Overview](https://pytorch.org/executorch/main/backends-overview.html).
 
+### Quantization
+We currently support Post-Training Quantization (PTQ) for linear layers using int8 dynamic per-token activations and int4 grouped per-channel weights (aka `8da4w`), as well as int8 channelwise embedding quantization.
+
+🚀 Stay tuned as more optimizations and performance enhancements are coming soon!
+
 
 ## 🛠️ Advanced Usage
 
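The README section above summarizes the new quantization support. For reference, a minimal sketch of how the `qlinear` and `qembedding` options introduced in this commit can be used from Python, mirroring the tests added later in this commit (`Qwen/Qwen3-0.6B` is simply the model used in those tests, not a requirement):

```python
# Hedged sketch: export and run a decoder-only model with the new quantization options.
# `qlinear` enables 8da4w linear quantization, `qembedding` enables int8 channelwise
# embedding quantization; both are forwarded as export kwargs by from_pretrained.
from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "Qwen/Qwen3-0.6B"  # example model id, taken from the new Qwen3 test
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="custom_sdpa",  # custom SDPA operator from ExecuTorch
    qlinear=True,       # int8 dynamic per-token activations + int4 grouped weights
    qembedding=True,    # int8 channelwise embedding quantization
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model.text_generation(tokenizer=tokenizer, prompt="Hello", max_seq_len=64))
```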

optimum/commands/export/executorch.py

Lines changed: 16 additions & 0 deletions

@@ -57,6 +57,18 @@ def parse_args_executorch(parser):
         action="store_true",
         help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
     )
+    required_group.add_argument(
+        "--qlinear",
+        required=False,
+        action="store_true",
+        help="Quantization config for linear layers. If set, defaults to '8da4w' w/ groupsize 32.",
+    )
+    required_group.add_argument(
+        "--qembedding",
+        required=False,
+        action="store_true",
+        help="Quantization config for embedding. If set, defaults to int8 channelwise.",
+    )
 
 
 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -72,6 +84,10 @@ def run(self):
         kwargs = {}
         if self.args.use_custom_sdpa:
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
+        if self.args.qlinear:
+            kwargs["qlinear"] = self.args.qlinear
+        if self.args.qembedding:
+            kwargs["qembedding"] = self.args.qembedding
 
         main_export(
             model_name_or_path=self.args.model,
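The new `--qlinear` and `--qembedding` flags are simply forwarded to `main_export` as boolean kwargs, as `run()` shows above. A hedged sketch of the programmatic equivalent; only `model_name_or_path` is visible in the diff, so the remaining argument names and the output directory are assumptions based on the CLI flags:

```python
# Hedged sketch of calling main_export directly instead of going through
# `optimum-cli export executorch ... --qlinear --qembedding` (argument names
# other than model_name_or_path are assumed, not taken from this diff).
from optimum.exporters.executorch import main_export

main_export(
    model_name_or_path="HuggingFaceTB/SmolLM2-135M",  # example model id
    task="text-generation",          # assumed to mirror the --task CLI flag
    recipe="xnnpack",                # assumed to mirror the --recipe CLI flag
    output_dir="./smollm2_exported", # assumed to mirror the --output_dir CLI flag
    use_custom_sdpa=True,
    qlinear=True,      # consumed via kwargs.get("qlinear") in the causal LM task
    qembedding=True,   # consumed via kwargs.get("qembedding")
)
```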

optimum/executorch/modeling.py

Lines changed: 4 additions & 1 deletion

@@ -37,6 +37,7 @@
 from transformers.utils import is_offline_mode
 
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
+from executorch.kernels import quantized  # noqa
 
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
@@ -180,7 +181,9 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )
         model = _load_for_executorch(model_cache_path)
-        logging.info(f"Loaded model from {model_cache_path}")
+        logging.info(
+            f"Loaded model from {model_cache_path} ({os.path.getsize(model_cache_path) / (1024 * 1024):.2f} MB)"
+        )
 
         return {default_file_name.removesuffix(_PTE_SUFFIX): model}
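The new `from executorch.kernels import quantized` import appears to be there for its side effect: it registers ExecuTorch's quantized operator kernels with the pybindings runtime so that quantized `.pte` files load correctly, which is why it is kept despite being unused (`# noqa`). A minimal sketch, with a hypothetical local path:

```python
# Hedged sketch: import the quantized kernels module for its registration side effect
# before loading a quantized program (e.g. one exported with qlinear/qembedding).
from executorch.kernels import quantized  # noqa: F401  (registers quantized kernels)
from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("model.pte")  # "model.pte" is a hypothetical path
```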

optimum/exporters/executorch/convert.py

Lines changed: 3 additions & 3 deletions

@@ -26,8 +26,6 @@
 from .recipe_registry import discover_recipes, recipe_registry
 
 
-logger = logging.getLogger(__name__)
-
 AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
 
 
@@ -82,6 +80,8 @@ def export_to_executorch(
         full_path = os.path.join(f"{output_dir}", f"{name}.pte")
         with open(full_path, "wb") as f:
             prog.write_to_file(f)
-        logger.info(f"Saved exported program to {full_path}")
+        logging.info(
+            f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
+        )
 
     return executorch_progs

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 20 additions & 4 deletions

@@ -15,9 +15,13 @@
 import logging
 from typing import Dict, Union
 
+from packaging.version import parse
+from tabulate import tabulate
 from torch.export import ExportedProgram
 
+from executorch import version as executorch_version
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import (
     EdgeCompileConfig,
     ExecutorchBackendConfig,
@@ -60,7 +64,14 @@ def _lower_to_executorch(
     metadata=None,
 ) -> Dict[str, ExecutorchProgram]:
     et_progs = {}
+    backend_config_dict = {
+        "extract_delegate_segments": True,
+    }
+    if parse(executorch_version.__version__).base_version > "0.6.0":
+        backend_config_dict["do_quant_fusion_and_const_prop"] = True
+
     for pte_name, exported_program in exported_programs.items():
+        logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
         et_progs[pte_name] = to_edge_transform_and_lower(
             exported_program,
             partitioner=[XnnpackPartitioner()],
@@ -69,11 +80,16 @@ def _lower_to_executorch(
             ),
             constant_methods=metadata,
         ).to_executorch(
-            config=ExecutorchBackendConfig(
-                extract_delegate_segments=True,
-            ),
+            config=ExecutorchBackendConfig(**backend_config_dict),
+        )
+        logging.debug(
+            f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
+        )
+        delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module)
+        logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
+        logging.debug(
+            f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
         )
-        logging.debug(f"Exported program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}")
     return et_progs
 
     exported_progs = model.export()
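The version gate above decides whether to enable `do_quant_fusion_and_const_prop` in the backend config. A small standalone sketch (illustrative version strings only) of why the recipe compares `base_version`: for a nightly such as `0.7.0.devYYYYMMDD`, `base_version` strips the dev suffix, so the check still triggers on nightly wheels. Note the `>` here is a plain string comparison, which is adequate for these particular values:

```python
# Hedged, standalone illustration of the base_version check used in the recipe.
from packaging.version import parse

for v in ("0.6.0", "0.7.0.dev20250501", "0.7.0"):
    print(v, "->", parse(v).base_version, parse(v).base_version > "0.6.0")
# 0.6.0            -> 0.6.0 False
# 0.7.0.dev20250501 -> 0.7.0 True
# 0.7.0            -> 0.7.0 True
```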

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 49 additions & 0 deletions

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
+import torch
+import torchao
+from packaging.version import parse
 from transformers import AutoModelForCausalLM, GenerationConfig
 
 from ..integrations import CausalLMExportableModule
@@ -71,4 +76,48 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
             },
         ),
     )
+
+    # TODO: Move quantization recipe out for better composability.
+    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
+    if qlinear_config or qembedding_config:
+        # TODO: Update torchao to use 0.11.0 once released
+        if parse(torchao.__version__) < parse("0.11.0.dev0"):
+            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
+
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            IntxWeightOnlyConfig,
+            quantize_,
+        )
+        from torchao.utils import unwrap_tensor_subclass
+
+        if qembedding_config:
+            logging.info("Quantizing embedding layers.")
+            # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
+            embedding_config = IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            )
+            quantize_(
+                eager_model,
+                embedding_config,
+                lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            )
+
+        if qlinear_config:
+            logging.info("Quantizing linear layers.")
+            linear_config = Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            )
+            quantize_(
+                eager_model,
+                linear_config,
+            )
+
+        unwrap_tensor_subclass(eager_model)
+
     return CausalLMExportableModule(eager_model)
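For reference, a minimal standalone sketch of the same torchao recipe applied directly to an eager Hugging Face model, outside the exporter. It mirrors the code added above, assumes torchao >= 0.11.0, and uses `HuggingFaceTB/SmolLM2-135M` purely as an example model id; normally `load_causal_lm_model` does this for you when `qlinear`/`qembedding` are passed:

```python
# Hedged sketch of the 8da4w + int8 embedding quantization recipe from this commit.
import torch
from transformers import AutoModelForCausalLM

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # example id

# int8 channelwise (per-axis) weight-only quantization of embedding layers
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# 8da4w: int8 dynamic per-token activations + int4 weights, group size 32, for linear layers
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4, weight_granularity=PerGroup(32)),
)

unwrap_tensor_subclass(model)  # prepare tensor subclasses for torch.export
```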

tests/models/test_modeling_gemma3.py

Lines changed: 41 additions & 0 deletions

@@ -22,7 +22,9 @@
 import unittest
 
 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -167,3 +169,42 @@ def test_gemma3_text_generation_with_custom_sdpa_float16(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
+        # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
+        # model_id = "google/gemma-3-1b-it"
+        model_id = "unsloth/gemma-3-1b-it"
+        prompt = "Write a poem about a machine learning."
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization
+        kwargs = {"qlinear": True}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

tests/models/test_modeling_qwen3.py

Lines changed: 38 additions & 0 deletions

@@ -21,7 +21,9 @@
 import unittest
 
 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -136,3 +138,39 @@ def test_qwen3_text_generation_with_custom_sdpa_float16(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_qwen3_text_generation_with_custom_sdpa_8da4w_8we(self):
+        model_id = "Qwen/Qwen3-0.6B"
+        prompt = "Give me a short introduction to large language model."
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
+        kwargs = {"qlinear": True, "qembedding": True}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=128,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
