Skip to content

Commit 04a3f63

Browse files
committed
Add CoreML recipe
1 parent de493ba commit 04a3f63

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed

optimum/exporters/executorch/__main__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def main_export(
130130
kwargs["force_download"] = force_download
131131
kwargs["config"] = config
132132

133+
recipe_kwargs = kwargs.pop("recipe_kwargs", {})
134+
133135
model = task_func(model_name_or_path, **kwargs)
134136

135137
if not os.path.exists(output_dir):
@@ -140,7 +142,7 @@ def main_export(
140142
task=task,
141143
recipe=recipe,
142144
output_dir=output_dir,
143-
**kwargs,
145+
**recipe_kwargs,
144146
)
145147

146148

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from typing import Dict, Union
17+
18+
from packaging.version import parse
19+
from tabulate import tabulate
20+
from torch.export import ExportedProgram
21+
import coremltools as ct
22+
import torch
23+
24+
from executorch import version as executorch_version
25+
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
26+
from executorch.backends.apple.coreml.compiler import CoreMLBackend
27+
28+
from executorch.devtools.backend_debug import get_delegation_info
29+
from executorch.exir import (
30+
EdgeCompileConfig,
31+
ExecutorchBackendConfig,
32+
ExecutorchProgram,
33+
to_edge_transform_and_lower,
34+
)
35+
from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
36+
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
37+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
38+
from ..integrations import (
39+
CausalLMExportableModule,
40+
MaskedLMExportableModule,
41+
Seq2SeqLMExportableModule,
42+
)
43+
from ..recipe_registry import register_recipe
44+
45+
def get_quantization_config():
46+
quantization_config = ct.optimize.torch.quantization.LinearQuantizerConfig.from_dict(
47+
{
48+
"global_config": {
49+
"quantization_scheme": ct.optimize.torch.quantization.QuantizationScheme.symmetric,
50+
"activation_dtype": torch.quint8,
51+
"weight_dtype": torch.qint8,
52+
"weight_per_channel": True,
53+
}
54+
}
55+
)
56+
return quantization_config
57+
58+
def quantize_program(ep):
59+
quantizer = CoreMLQuantizer(get_quantization_config())
60+
gm = ep.module()
61+
62+
args, kwargs = ep.example_inputs
63+
prepared_model = prepare_pt2e(gm, quantizer)
64+
prepared_model(*args, **kwargs)
65+
converted_model = convert_pt2e(prepared_model)
66+
return torch.export.export(converted_model, args, kwargs)
67+
68+
69+
@register_recipe("coreml")
70+
def export_to_executorch_with_coreml(
71+
model: Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule],
72+
**kwargs,
73+
):
74+
"""
75+
Export a PyTorch model to ExecuTorch w/ delegation to CoreML backend.
76+
77+
This function also write metadata required by the ExecuTorch runtime to the model.
78+
79+
Args:
80+
model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule]):
81+
The PyTorch model to be exported to ExecuTorch.
82+
**kwargs:
83+
Additional keyword arguments for recipe-specific configurations, e.g. export using different example inputs, or different compile/bechend configs.
84+
85+
Returns:
86+
Dict[str, ExecutorchProgram]:
87+
A map of exported and optimized program for ExecuTorch.
88+
For encoder-decoder models or multimodal models, it may generate multiple programs.
89+
"""
90+
91+
def _lower_to_executorch(
92+
exported_programs: Dict[str, ExportedProgram],
93+
metadata=None,
94+
**kwargs,
95+
) -> Dict[str, ExecutorchProgram]:
96+
97+
minimum_deployment_target = kwargs.get("minimum_ios_deployment_target", "15")
98+
minimum_deployment_target = {
99+
"15": ct.target.iOS15,
100+
"16": ct.target.iOS16,
101+
"17": ct.target.iOS17,
102+
"18": ct.target.iOS18,
103+
}[minimum_deployment_target]
104+
105+
compute_precision = kwargs.get("compute_precision", "fp16")
106+
compute_precision = {
107+
"fp16": ct.precision.FLOAT16,
108+
"fp32": ct.precision.FLOAT32,
109+
}[compute_precision]
110+
111+
model_type = kwargs.get("model_type", "model")
112+
model_type = {
113+
"model": CoreMLBackend.MODEL_TYPE.MODEL,
114+
"modelc": CoreMLBackend.MODEL_TYPE.COMPILED_MODEL,
115+
}[model_type]
116+
take_over_mutable_buffer = kwargs.get("take_over_mutable_buffer", True)
117+
quantize = kwargs.get("quantize", False)
118+
119+
et_progs = {}
120+
backend_config_dict = {}
121+
for pte_name, exported_program in exported_programs.items():
122+
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
123+
if quantize:
124+
exported_program = quantize_program(exported_program)
125+
et_progs[pte_name] = to_edge_transform_and_lower(
126+
exported_program,
127+
partitioner=[CoreMLPartitioner(
128+
compile_specs=CoreMLBackend.generate_compile_specs(
129+
minimum_deployment_target=minimum_deployment_target,
130+
compute_precision=compute_precision,
131+
model_type=model_type,
132+
),
133+
take_over_mutable_buffer=take_over_mutable_buffer, # Fails when set to true
134+
)],
135+
compile_config=EdgeCompileConfig(
136+
_check_ir_validity=False,
137+
_skip_dim_order=True,
138+
),
139+
constant_methods=metadata,
140+
).to_executorch(
141+
config=ExecutorchBackendConfig(**backend_config_dict),
142+
)
143+
logging.debug(
144+
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
145+
)
146+
delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module)
147+
logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
148+
logging.debug(
149+
f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
150+
)
151+
return et_progs
152+
153+
exported_progs = model.export()
154+
return _lower_to_executorch(exported_progs, model.metadata, **kwargs)

0 commit comments

Comments
 (0)