Commit a2de3dc

Author: Guang Yang (committed)
Message: add support for embedding quantization
Parent: 4756ed7

File tree: 9 files changed (+152, -41 lines)


.github/workflows/test_models.yml

Lines changed: 3 additions & 2 deletions
@@ -50,14 +50,15 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies for ExecuTorch
+      # Consolidate torchao nightly version once https://github.com/pytorch/ao/issues/2157 is fixed
       run: |
         if [ "${{ matrix.executorch-version }}" == "nightly" ]; then
-          export NIGHTLY_VERSION=dev20250422
+          export NIGHTLY_VERSION=dev20250501
           pip install executorch==0.7.0.${NIGHTLY_VERSION} \
             torch==2.8.0.${NIGHTLY_VERSION} \
             torchvision==0.22.0.${NIGHTLY_VERSION} \
             torchaudio==2.6.0.${NIGHTLY_VERSION} \
-            torchao==0.11.0.${NIGHTLY_VERSION} \
+            torchao==0.11.0.dev20250422 \
             --extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
         else
           pip install executorch==${{ matrix.executorch-version }}

README.md

Lines changed: 3 additions & 3 deletions
@@ -129,7 +129,7 @@ generated_text = model.text_generation(
 print(generated_text)
 ```

-## Supported Models and Backend
+## Supported Models

 **Optimum-ExecuTorch** currently supports the following transformer models:

@@ -174,9 +174,9 @@ We currently support a wide range of popular transformer models, including encod

 *📌 Note: This list is continuously expanding. As we continue to expand support, more models will be added.*

-**Supported Backend:**
+## Supported Optimizations

-Currently, **Optimum-ExecuTorch** supports only the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) for efficient execution on mobile CPUs. We currently support Post-Training Quantization (PTQ) for linear layers using int8 dynamic per-token activations and int4 grouped per-channel weights (`8da4w`).
+Currently, **Optimum-ExecuTorch** supports the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) with [custom SDPA](https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40) for efficient execution on mobile CPUs. We currently support Post-Training Quantization (PTQ) for linear layers using int8 dynamic per-token activations and int4 grouped per-channel weights (`8da4w`), and int8 channelwise embedding quantization.

 For a comprehensive overview of all backends supported by ExecuTorch, please refer to the [ExecuTorch Backend Overview](https://pytorch.org/executorch/main/backends-overview.html).
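To make the new options concrete, the snippet below mirrors the updated tests in this commit and loads a model with both linear (`8da4w`) and embedding (int8 channelwise) quantization enabled. It is a sketch only; the model id and generation settings are illustrative.

from transformers import AutoTokenizer
from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "HuggingFaceTB/SmolLM2-135M"  # example model, also used in the updated tests

# Export to ExecuTorch via the XNNPACK recipe with custom SDPA plus both quantization options
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    qlinear=True,      # int8 dynamic per-token activations + int4 grouped weights (8da4w)
    qembedding=True,   # int8 channelwise embedding quantization
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=64))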

optimum/commands/export/executorch.py

Lines changed: 13 additions & 6 deletions
@@ -58,11 +58,16 @@ def parse_args_executorch(parser):
         help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
     )
     required_group.add_argument(
-        "-qmode",
-        "--quantization_mode",
+        "--qlinear",
         required=False,
-        choices=["8da4w"],
-        help="Quantization recipe to use. Defaults to None.",
+        action="store_true",
+        help="Quantization config for linear layers. If set, defaults to '8da4w' w/ groupsize 32.",
+    )
+    required_group.add_argument(
+        "--qembedding",
+        required=False,
+        action="store_true",
+        help="Quantization config for embedding. If set, defaults to int8 channelwise.",
     )


@@ -79,8 +84,10 @@ def run(self):
         kwargs = {}
         if self.args.use_custom_sdpa:
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
-        if self.args.quantization_mode:
-            kwargs["quantization_mode"] = self.args.quantization_mode
+        if self.args.qlinear:
+            kwargs["qlinear"] = self.args.qlinear
+        if self.args.qembedding:
+            kwargs["qembedding"] = self.args.qembedding

         main_export(
             model_name_or_path=self.args.model,
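With these arguments in place, a command-line export might look like the sketch below. The `optimum-cli export executorch` subcommand shape, task name, and output directory are assumptions based on the existing CLI; only `--use_custom_sdpa`, `--qlinear`, and `--qembedding` are taken from this file.

optimum-cli export executorch \
  --model HuggingFaceTB/SmolLM2-135M \
  --task text-generation \
  --recipe xnnpack \
  --use_custom_sdpa \
  --qlinear \
  --qembedding \
  --output_dir ./smollm2_8da4w_8we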

optimum/executorch/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 from transformers.utils import is_offline_mode

 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
+from executorch.kernels import quantized  # noqa

 from ..exporters import TasksManager
 from ..exporters.executorch import main_export

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 9 additions & 4 deletions
@@ -15,9 +15,11 @@
 import logging
 from typing import Dict, Union

+from packaging.version import parse
 from tabulate import tabulate
 from torch.export import ExportedProgram

+from executorch import version as executorch_version
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import (
@@ -62,6 +64,12 @@ def _lower_to_executorch(
     metadata=None,
 ) -> Dict[str, ExecutorchProgram]:
     et_progs = {}
+    backend_config_dict = {
+        "extract_delegate_segments": True,
+    }
+    if parse(executorch_version.__version__).base_version > "0.6.0":
+        backend_config_dict["do_quant_fusion_and_const_prop"] = True
+
     for pte_name, exported_program in exported_programs.items():
         et_progs[pte_name] = to_edge_transform_and_lower(
             exported_program,
@@ -71,10 +79,7 @@ def _lower_to_executorch(
             ),
             constant_methods=metadata,
         ).to_executorch(
-            config=ExecutorchBackendConfig(
-                do_quant_fusion_and_const_prop=True,
-                extract_delegate_segments=True,
-            ),
+            config=ExecutorchBackendConfig(**backend_config_dict),
         )
         logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
         logging.debug(
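A quick sketch of what the version gate above evaluates to, using an illustrative version string (not taken from this commit): `base_version` drops any dev/pre-release suffix, and the comparison against "0.6.0" is a plain string comparison, which is sufficient for the 0.x versions involved here.

from packaging.version import parse

# Hypothetical nightly-style version string; base_version strips the ".dev..." suffix.
v = parse("0.7.0.dev20250501")
print(v.base_version)            # "0.7.0"
print(v.base_version > "0.6.0")  # True -> do_quant_fusion_and_const_prop gets enabled

# Note this is a lexicographic string comparison: it holds for "0.7.0" vs "0.6.0",
# but would not generalize to, e.g., "0.10.0" > "0.6.0" (which is False as strings).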

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 36 additions & 13 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+
 import torch
 import torchao
 from packaging.version import parse
@@ -57,14 +59,12 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)
-    quantization_mode = kwargs.get("quantization_mode", None)

     eager_model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         device_map=device,
         torch_dtype=dtype,
         config=config,
-        # quantization_config=quantization_config,
         attn_implementation=attn_implementation,
         generation_config=GenerationConfig(
             use_cache=True,
@@ -77,24 +77,47 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
         ),
     )

-    if quantization_mode == "8da4w":
+    # TODO: Move quantization recipe out for better composability.
+    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
+    if qlinear_config or qembedding_config:
+        # TODO: Update torchao to use 0.11.0 once released
         if parse(torchao.__version__) < parse("0.11.0.dev0"):
             raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")

-        from torchao.quantization.granularity import PerGroup
+        from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
+            IntxWeightOnlyConfig,
+            quantize_,
         )
+        from torchao.utils import unwrap_tensor_subclass

-        # TODO: Should switch to TorchAoConfig once the quant issue on final lm_head layer is fixed.
-        linear_config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=torch.int4,
-            weight_granularity=PerGroup(64),
-        )
+        if qembedding_config:
+            logging.info("Quantizing embedding layers.")
+            # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
+            embedding_config = IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            )
+            quantize_(
+                eager_model,
+                embedding_config,
+                lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            )

-        torchao.quantize_(
-            eager_model,
-            linear_config,
-        )
+        if qlinear_config:
+            logging.info("Quantizing linear layers.")
+            linear_config = Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            )
+            quantize_(
+                eager_model,
+                linear_config,
+            )
+
+        unwrap_tensor_subclass(eager_model)

     return CausalLMExportableModule(eager_model)
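The quantization recipe above can be exercised in isolation. The following is a minimal sketch using the same torchao APIs on a toy module (it assumes torchao >= 0.11.0.dev0, as required by the version check in this file; the toy dimensions are arbitrary).

import torch
import torch.nn as nn
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Toy module with one embedding and one linear layer (illustrative only)
toy = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))

# int8 weight-only, per-channel (axis 0) quantization applied to embedding modules
quantize_(
    toy,
    IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# 8da4w: int8 dynamic per-token activations + int4 weights with group size 32,
# applied to linear layers (quantize_'s default filter targets nn.Linear)
quantize_(
    toy,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    ),
)

# Flatten torchao tensor subclasses so the quantized model can go through torch.export
unwrap_tensor_subclass(toy)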

tests/models/test_modeling_gemma3.py

Lines changed: 41 additions & 2 deletions
@@ -181,10 +181,48 @@ def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
         # model_id = "google/gemma-3-1b-it"
         model_id = "unsloth/gemma-3-1b-it"
         prompt = "Write a poem about a machine learning."
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization
+        kwargs = {"qlinear": True}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        kwargs = {"quantize": "8da4w"}
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

-        # ExecuTorch model + custom sdpa + float16
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_gemma3_text_generation_with_custom_sdpa_8da4w_8we(self):
+        # TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
+        # model_id = "google/gemma-3-1b-it"
+        model_id = "unsloth/gemma-3-1b-it"
+        prompt = "Write a poem about a machine learning."
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
+        kwargs = {"qlinear": True, "qembedding": True}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -194,6 +232,7 @@ def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt=prompt,

tests/models/test_modeling_qwen3.py

Lines changed: 4 additions & 4 deletions
@@ -145,13 +145,13 @@ def test_qwen3_text_generation_with_custom_sdpa_float16(self):
         parse(torchao.__version__) < parse("0.11.0.dev0"),
         reason="Only available on torchao >= 0.11.0.dev0",
     )
-    def test_qwen3_text_generation_with_custom_sdpa_8da4w(self):
+    def test_qwen3_text_generation_with_custom_sdpa_8da4w_8we(self):
         model_id = "Qwen/Qwen3-0.6B"
         prompt = "Give me a short introduction to large language model."
         tokenizer = AutoTokenizer.from_pretrained(model_id)

-        # ExecuTorch model + custom sdpa
-        kwargs = {"quantize": "8da4w"}
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
+        kwargs = {"qlinear": True, "qembedding": True}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -163,7 +163,7 @@ def test_qwen3_text_generation_with_custom_sdpa_8da4w(self):
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt=prompt,
-            max_seq_len=64,
+            max_seq_len=128,
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

tests/models/test_modeling_smollm.py

Lines changed: 42 additions & 7 deletions
@@ -82,8 +82,6 @@ def test_smollm_text_generation(self):
     def test_smollm_text_generation_with_custom_sdpa(self):
         model_id = "HuggingFaceTB/SmolLM2-135M"
         prompt = "My favourite condiment is "
-        max_seq_len = 32
-        tokenizer = AutoTokenizer.from_pretrained(model_id)

         # ExecuTorch model + custom sdpa
         model = ExecuTorchModelForCausalLM.from_pretrained(
@@ -94,10 +92,11 @@ def test_smollm_text_generation_with_custom_sdpa(self):
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt=prompt,
-            max_seq_len=max_seq_len,
+            max_seq_len=32,
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
@@ -118,11 +117,46 @@ def test_smollm_text_generation_with_custom_sdpa(self):
     def test_smollm_text_generation_with_custom_sdpa_8da4w(self):
         model_id = "HuggingFaceTB/SmolLM2-135M"
         prompt = "My favourite condiment is "
-        max_seq_len = 32
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization
+        kwargs = {"qlinear": True}
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            **kwargs,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

-        # ExecuTorch model + custom sdpa
-        kwargs = {"quantize": "8da4w"}
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_smollm_text_generation_with_custom_sdpa_8da4w_8we(self):
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+
+        # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
+        kwargs = {"qlinear": True, "qembedding": True}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -132,10 +166,11 @@ def test_smollm_text_generation_with_custom_sdpa_8da4w(self):
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         generated_text = model.text_generation(
             tokenizer=tokenizer,
             prompt=prompt,
-            max_seq_len=max_seq_len,
+            max_seq_len=64,
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
