Commit e1d0eb4

Enable loading model from the HF hub
1 parent: d600659

File tree: 1 file changed (+79, -55 lines)

optimum/executorch/modeling.py

Lines changed: 79 additions & 55 deletions
@@ -16,23 +16,20 @@

 import logging
 import os
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import List, Optional, Union

 import torch
+from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from transformers import (
     AutoModelForCausalLM,
     PretrainedConfig,
     PreTrainedTokenizer,
 )

-from executorch.extension.pybindings.portable_lib import (
-    ExecuTorchModule,
-    _load_for_executorch,
-)
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch

 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
@@ -53,7 +50,7 @@ class ExecuTorchModelForCausalLM(OptimizedModel):
     Attributes:
         auto_model_class (`Type`):
             Associated Transformers class, `AutoModelForCausalLM`.
-        et_model (`ExecuTorchModule`):
+        model (`ExecuTorchModule`):
             The loaded ExecuTorch model.
         use_kv_cache (`bool`):
             Whether key-value caching is enabled. For performance reasons, the exported model is
@@ -74,29 +71,25 @@ class ExecuTorchModelForCausalLM(OptimizedModel):

     auto_model_class = AutoModelForCausalLM

-    def __init__(
-        self,
-        model: "ExecuTorchModule",
-        config: "PretrainedConfig",
-    ):
+    def __init__(self, model: "ExecuTorchModule", config: "PretrainedConfig"):
         super().__init__(model, config)
-        self.et_model = model
-        metadata = self.et_model.method_names()
+        # self.model = model
+        metadata = self.model.method_names()
         logging.info(f"Load all static methods: {metadata}")
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.et_model.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.et_model.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.et_model.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.et_model.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.et_model.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]

     def forward(
         self,
@@ -113,28 +106,33 @@ def forward(
         Returns:
             torch.Tensor: Logits output from the model.
         """
-        return self.et_model.forward((input_ids, cache_position))[0]
-
+        return self.model.forward((input_ids, cache_position))[0]
+
     @classmethod
     def _from_pretrained(
         cls,
-        model_dir_path: Union[str, Path],
+        model_id: Union[str, Path],
         config: PretrainedConfig,
+        token: Optional[Union[bool, str]] = None,
         subfolder: str = "",
         revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         force_download: bool = False,
         local_files_only: bool = False,
-        token: Optional[Union[bool, str]] = None,
+        **kwargs,
+        # model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
     ) -> "ExecuTorchModelForCausalLM":
         """
-        Load a pre-trained ExecuTorch model from a local directory.
+        Load a pre-trained ExecuTorch model from a local directory or hosted on the HF hub.

         Args:
-            model_dir_path (`Union[str, Path]`):
+            model_id (`Union[str, Path]`):
                 Path to the directory containing the ExecuTorch model file (`model.pte`).
             config (`PretrainedConfig`, *optional*):
                 Configuration of the pre-trained model.
+            token (`Optional[Union[bool,str]]`, defaults to `None`):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
             subfolder (`str`, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
                 specify the folder name here.
@@ -147,43 +145,71 @@ def _from_pretrained(
                 cached versions if they exist.
             local_files_only (`Optional[bool]`, defaults to `False`):
                 Whether or not to only look at local files (i.e., do not try to download the model).
-            token (`Optional[Union[bool,str]]`, defaults to `None`):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).

         Returns:
             `ExecuTorchModelForCausalLM`: The initialized ExecuTorch model.

         """
-        full_path = os.path.join(f"{model_dir_path}", "model.pte")
-        model = _load_for_executorch(full_path)
-        logging.info(f"Loaded model from {full_path}")
-        logging.debug(f"{model.method_meta('forward')}")
-        return cls(
-            model=model,
-            config=config,
+        model_path = Path(model_id)
+        default_file_name = "model.pte"
+
+        model_cache_path = cls._cached_file(
+            model_path=model_path,
+            token=token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            file_name=default_file_name,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
         )
+        model = _load_for_executorch(model_cache_path)
+        logging.info(f"Loaded model from {model_cache_path}")

-    def _save_pretrained(self, save_directory):
-        """
-        Saves a model weights into a directory, so that it can be re-loaded using the
-        [`from_pretrained`] class method.
-        """
-        raise NotImplementedError
+        return cls(model, config=config)
+
+    @staticmethod
+    def _cached_file(
+        model_path: Union[Path, str],
+        token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        file_name: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+    ):
+        model_path = Path(model_path)
+        # locates a file in a local folder and repo, downloads and cache it if necessary.
+        if model_path.is_dir():
+            model_cache_path = os.path.join(model_path, subfolder, file_name)
+        else:
+            model_cache_path = hf_hub_download(
+                repo_id=model_path.as_posix(),
+                filename=file_name,
+                subfolder=subfolder,
+                token=token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+            )
+
+        return model_cache_path

     @classmethod
     def _export(
         cls,
         model_id: str,
         recipe: str,
         config: PretrainedConfig,
+        token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         trust_remote_code: bool = False,
         subfolder: str = "",
-        revision: Optional[str] = None,
         force_download: bool = False,
         local_files_only: bool = False,
-        token: Optional[Union[bool, str]] = None,
         **kwargs,
     ):
         """
@@ -228,6 +254,7 @@ def _export(

         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
+
         # Export to ExecuTorch and save the pte file to the temporary directory
         main_export(
             model_name_or_path=model_id,
@@ -243,17 +270,14 @@ def _export(
             trust_remote_code=trust_remote_code,
             **kwargs,
         )
+        return cls._from_pretrained(model_id=save_dir_path, config=config)

-        return cls._from_pretrained(
-            model_dir_path=save_dir_path,
-            config=config,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            token=token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-        )
+    def _save_pretrained(self, save_directory):
+        """
+        Saves a model weights into a directory, so that it can be re-loaded using the
+        [`from_pretrained`] class method.
+        """
+        raise NotImplementedError

     def generate(
         self,
@@ -367,4 +391,4 @@ def text_generation(
     def _from_transformers(cls, *args, **kwargs):
         # TODO : add warning when from_pretrained_method is set to cls._export instead of cls._from_transformers when export=True
         # logger.warning("The method `_from_transformers` is deprecated, please use `_export` instead")
-        return cls._export(*args, **kwargs)
+        return cls._export(*args, **kwargs)
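
With the change in place, `from_pretrained` covers both load paths. A hedged usage sketch: the Hub repo id in case 1 is a placeholder, case 2 assumes the class's existing `export=True` + `recipe` flow, and the final call assumes `text_generation` keeps the tokenizer/prompt signature defined later in this file:

from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

# Case 1 (new in this commit): load a pre-exported model.pte directly from the Hub.
# "some-user/smollm-executorch" is a placeholder repo id, not a real repo.
model = ExecuTorchModelForCausalLM.from_pretrained("some-user/smollm-executorch")

# Case 2 (unchanged): export a Transformers checkpoint on the fly, then load the temporary .pte.
model = ExecuTorchModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM-135M", export=True, recipe="xnnpack"
)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
print(model.text_generation(tokenizer=tokenizer, prompt="Hello", max_seq_len=32))

Chaining `_export` into `cls._from_pretrained(model_id=save_dir_path, config=config)` is what makes the two cases symmetric: the export path simply becomes a local-directory load.
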
