
Commit 976ccfd

Author: Guang Yang
Committed: May 20, 2025
Support lowering quantized checkpoint from Hub
1 parent da80c9e commit 976ccfd

File tree: 5 files changed (+146, -10 lines)

 

optimum/executorch/modeling.py

Lines changed: 15 additions & 9 deletions
@@ -92,8 +92,12 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
                 f"This attribute is used to identify the corresponding AutoModel class."
             )
 
-        for key, value in models.items():
-            setattr(self, key, value)
+        if len(models) == 1:
+            # For single PTE, always set the attr to "model"
+            setattr(self, "model", next(iter(models.values())))
+        else:
+            for key, value in models.items():
+                setattr(self, key, value)
 
         self.stats = Stats()
 
@@ -570,8 +574,8 @@ class ExecuTorchModelForCausalLM(ExecuTorchModelBase):
             Data type of the model parameters.
         bos_token_id (`int`):
             Beginning-of-sequence token ID.
-        eos_token_id (`int`):
-            End-of-sequence token ID.
+        eos_token_ids (`List[int]`):
+            End-of-sequence token IDs.
         vocab_size (`int`):
             Size of the model vocabulary.
     """
@@ -594,8 +598,10 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
             self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
            self.bos_token_id = self.model.run_method("get_bos_id")[0]
-        if "get_eos_id" in metadata:
-            self.eos_token_id = self.model.run_method("get_eos_id")[0]
+        for key in ("get_eos_id", "get_eos_ids"):
+            if key in metadata:
+                self.eos_token_ids = self.model.run_method(key)
+                break
        if "get_vocab_size" in metadata:
            self.vocab_size = self.model.run_method("get_vocab_size")[0]
        if "use_sdpa_with_kv_cache" in metadata:
@@ -694,7 +700,7 @@ def generate(
             next_token = torch.argmax(logits, dim=-1).item()
             generated_tokens.append(next_token)
 
-            if next_token == self.eos_token_id:
+            if next_token in self.eos_token_ids:
                 break
 
         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
@@ -730,9 +736,9 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
+        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id not in self.eos_token_ids:
             raise ValueError(
-                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
+                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must match with the model's eos_token_ids={self.eos_token_ids}."
             )
 
         # Reset stats for a new generation
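
With this change the single EOS id becomes a list (`eos_token_ids`), populated from whichever of `get_eos_id` / `get_eos_ids` the PTE's metadata exposes, so generation stops on any of the listed ids. A minimal, self-contained illustration of the new stop condition (the token ids below are made up and only stand in for the values returned by `run_method`):

eos_token_ids = [199999, 200020]  # illustrative ids, as if returned by run_method("get_eos_ids")

generated_tokens = []
for next_token in (11, 42, 200020, 7):
    generated_tokens.append(next_token)
    if next_token in eos_token_ids:  # previously: next_token == eos_token_id
        break

print(generated_tokens)  # [11, 42, 200020]
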
optimum/executorch/passes/remove_padding_idx_embedding_pass.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class RemovePaddingIdxEmbeddingPass(ExportPass):
+    """
+    An ExportPass that removes the `padding_idx` keyword argument
+    from all aten.embedding.default operator calls.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default:
+                # node.args[2] is the padding_idx
+                if len(node.args) == 3:
+                    node.args = (node.args[0], node.args[1])
+        graph_module.recompile()
+        return PassResult(graph_module, True)
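
The new pass strips the positional `padding_idx` argument from edge-dialect `aten.embedding.default` calls. A minimal usage sketch, assuming the file lives at the module path used by the import in the recipe below; the toy module and example input are illustrative, and in the recipe the pass is wired in via `transform_passes` rather than called directly:

import torch
from executorch.exir import to_edge
from torch.export import export

from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass


class TinyEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # padding_idx makes the exported embedding call carry a third positional arg
        self.emb = torch.nn.Embedding(num_embeddings=16, embedding_dim=4, padding_idx=0)

    def forward(self, ids):
        return self.emb(ids)


edge = to_edge(export(TinyEmbedding(), (torch.tensor([[1, 2, 0]]),)))
graph_module = edge.exported_program().graph_module
result = RemovePaddingIdxEmbeddingPass().call(graph_module)
assert result.modified  # the pass always reports the graph as modified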

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,7 @@
     ExecutorchProgram,
     to_edge_transform_and_lower,
 )
+from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
 
 from ..integrations import (
     CausalLMExportableModule,
@@ -76,9 +77,11 @@ def _lower_to_executorch(
             exported_program,
             partitioner=[XnnpackPartitioner()],
             compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
                 _skip_dim_order=True,
             ),
             constant_methods=metadata,
+            transform_passes=[RemovePaddingIdxEmbeddingPass()],
         ).to_executorch(
             config=ExecutorchBackendConfig(**backend_config_dict),
         )

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 5 additions & 0 deletions
@@ -77,6 +77,11 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
         ),
     )
 
+    for param in eager_model.parameters():
+        # Must disable gradient for quantized checkpoint
+        if isinstance(param, torchao.utils.TorchAOBaseTensor):
+            param.requires_grad = False
+
     # TODO: Move quantization recipe out for better composability.
     # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
     qlinear_config = kwargs.get("qlinear", None)
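
The guard above freezes torchao tensor-subclass parameters so a quantized checkpoint pulled from the Hub can be exported. A hedged sketch of the surrounding flow, assuming torchao >= 0.11 and the same Hub checkpoint used by the new tests; the variable names and `torch_dtype="auto"` are illustrative, not taken from this file:

from torchao.utils import TorchAOBaseTensor
from transformers import AutoModelForCausalLM

eager_model = AutoModelForCausalLM.from_pretrained(
    "pytorch/Phi-4-mini-instruct-8da4w", torch_dtype="auto"
)
for param in eager_model.parameters():
    if isinstance(param, TorchAOBaseTensor):
        # quantized tensor subclasses must not require grad before tracing with torch.export
        param.requires_grad = False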

tests/models/test_modeling_phi4.py

Lines changed: 101 additions & 1 deletion
@@ -15,10 +15,13 @@
 
 import gc
 import logging
+import os
 import unittest
 
 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -27,13 +30,21 @@
 from ..utils import check_causal_lm_output_quality
 
 
-@pytest.mark.skip(reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM")
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+is_ci = os.environ.get("GITHUB_ACTIONS") == "true"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     @slow
     @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        is_ci,
+        reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM",
+    )
     def test_phi4_text_generation(self):
         model_id = "microsoft/Phi-4-mini-instruct"
         config = AutoConfig.from_pretrained(model_id)
@@ -61,3 +72,92 @@ def test_phi4_text_generation(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_pte_from_hub(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id, recipe="xnnpack", config=config, file_name="phi4-mini-8da4w.pte"
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_ckp(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            config=config,
+            export=True,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
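
Condensed usage sketch of the Hub flow these tests exercise; the model id, recipe, and file name come from the tests above, the prompt is arbitrary, and note that the tests additionally patch `config.rope_scaling` before loading:

from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "pytorch/Phi-4-mini-instruct-8da4w"
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id, recipe="xnnpack", file_name="phi4-mini-8da4w.pte"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=64))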
