Skip to content

Commit dd505cc

Browse files
authored
Remove the task argument for ExecuTorchModelForXxx (#12)
* Remove the need to provide the task argument
* Update tests
1 parent b143d40 commit dd505cc

File tree

7 files changed

+6
-18
lines changed

7 files changed

+6
-18
lines changed

optimum/executorch/modeling.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_load_for_executorch,
3535
)
3636

37+
from ..exporters import TasksManager
3738
from ..exporters.executorch import main_export
3839
from ..modeling_base import OptimizedModel
3940

@@ -119,7 +120,6 @@ def from_pretrained(
119120
cls,
120121
model_name_or_path: Union[str, Path],
121122
export: bool = True,
122-
task: str = "",
123123
recipe: str = "",
124124
config: "PretrainedConfig" = None,
125125
subfolder: str = "",
@@ -140,8 +140,6 @@ def from_pretrained(
140140
export (`bool`, *optional*, defaults to `True`):
141141
If `True`, the model will be exported from eager to ExecuTorch after fetched from huggingface.co. `model_name_or_path` must be a valid model ID on huggingface.co.
142142
If `False`, the previously exported ExecuTorch model will be loaded from a local path. `model_name_or_path` must be a valid local directory where a `model.pte` is stored.
143-
task (`str`, defaults to `""`):
144-
The task to export the model for, e.g. "text-generation". It is required to specify a task when `export` is `True`.
145143
recipe (`str`, defaults to `""`):
146144
The recipe to use to do the export, e.g. "xnnpack". It is required to specify a recipe when `export` is `True`.
147145
config (`PretrainedConfig`, *optional*):
@@ -180,13 +178,10 @@ def from_pretrained(
180178

181179
if export:
182180
# Fetch the model from huggingface.co and export it to ExecuTorch
183-
if task == "":
184-
raise ValueError("Please specify a task to export the model for.")
185181
if recipe == "":
186182
raise ValueError("Please specify a recipe to export the model for.")
187183
return cls._export(
188184
model_id=model_name_or_path,
189-
task=task,
190185
recipe=recipe,
191186
config=config,
192187
**kwargs,
@@ -261,7 +256,6 @@ def _save_pretrained(self, save_directory):
261256
def _export(
262257
cls,
263258
model_id: str,
264-
task: str,
265259
recipe: str,
266260
config: PretrainedConfig,
267261
cache_dir: str = HUGGINGFACE_HUB_CACHE,
@@ -280,8 +274,6 @@ def _export(
280274
Args:
281275
model_id (`str`):
282276
Model ID on huggingface.co, for example: `model_name_or_path="meta-llama/Llama-3.2-1B"`.
283-
task (`str`):
284-
The task to export the model for, e.g. "text-generation".
285277
recipe (`str`):
286278
The recipe to use to do the export, e.g. "xnnpack".
287279
config (`PretrainedConfig`, *optional*):
@@ -314,6 +306,10 @@ def _export(
314306
`ExecuTorchModelForCausalLM`: The loaded and exported ExecuTorch model.
315307
316308
"""
309+
task = kwargs.pop("task", None)
310+
if task is not None:
311+
logger.warning(f"task was provided and set to {task} but not used, will be ignored")
312+
317313
if use_auth_token is not None:
318314
warnings.warn(
319315
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
@@ -325,12 +321,11 @@ def _export(
325321

326322
save_dir = TemporaryDirectory()
327323
save_dir_path = Path(save_dir.name)
328-
329324
# Export to ExecuTorch and save the pte file to the temporary directory
330325
main_export(
331326
model_name_or_path=model_id,
332327
output_dir=save_dir_path,
333-
task=task,
328+
task=TasksManager.infer_task_from_model(cls.auto_model_class),
334329
recipe=recipe,
335330
subfolder=subfolder,
336331
revision=revision,

tests/models/test_modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def test_load_model_from_hub(self):
3434
model = ExecuTorchModelForCausalLM.from_pretrained(
3535
model_name_or_path="NousResearch/Llama-3.2-1B",
3636
export=True,
37-
task="text-generation",
3837
recipe="xnnpack",
3938
)
4039
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

tests/models/test_modeling_gemma.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def test_gemma_text_generation_with_xnnpack(self):
3636
model = ExecuTorchModelForCausalLM.from_pretrained(
3737
model_name_or_path=model_id,
3838
export=True,
39-
task="text-generation",
4039
recipe="xnnpack",
4140
)
4241
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

tests/models/test_modeling_gemma2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def test_gemma2_text_generation_with_xnnpack(self):
3636
model = ExecuTorchModelForCausalLM.from_pretrained(
3737
model_name_or_path=model_id,
3838
export=True,
39-
task="text-generation",
4039
recipe="xnnpack",
4140
)
4241
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

tests/models/test_modeling_llama.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def test_llama3_2_1b_text_generation_with_xnnpack(self):
3636
model = ExecuTorchModelForCausalLM.from_pretrained(
3737
model_name_or_path=model_id,
3838
export=True,
39-
task="text-generation",
4039
recipe="xnnpack",
4140
)
4241
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
@@ -61,7 +60,6 @@ def test_llama3_2_3b_text_generation_with_xnnpack(self):
6160
model = ExecuTorchModelForCausalLM.from_pretrained(
6261
model_name_or_path=model_id,
6362
export=True,
64-
task="text-generation",
6563
recipe="xnnpack",
6664
)
6765
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

tests/models/test_modeling_olmo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def test_olmo_text_generation_with_xnnpack(self):
3434
model = ExecuTorchModelForCausalLM.from_pretrained(
3535
model_name_or_path=model_id,
3636
export=True,
37-
task="text-generation",
3837
recipe="xnnpack",
3938
)
4039
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

tests/models/test_modeling_qwen2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def test_qwen2_5_text_generation_with_xnnpack(self):
3434
model = ExecuTorchModelForCausalLM.from_pretrained(
3535
model_name_or_path=model_id,
3636
export=True,
37-
task="text-generation",
3837
recipe="xnnpack",
3938
)
4039
self.assertIsInstance(model, ExecuTorchModelForCausalLM)

0 commit comments

Comments
 (0)