16
16
17
17
import logging
18
18
import os
19
- import warnings
20
19
from pathlib import Path
21
20
from tempfile import TemporaryDirectory
22
21
from typing import List , Optional , Union
23
22
24
23
import torch
24
+ from huggingface_hub import hf_hub_download
25
25
from huggingface_hub .constants import HUGGINGFACE_HUB_CACHE
26
26
from transformers import (
27
27
AutoModelForCausalLM ,
28
28
PretrainedConfig ,
29
29
PreTrainedTokenizer ,
30
30
)
31
31
32
- from executorch .extension .pybindings .portable_lib import (
33
- ExecuTorchModule ,
34
- _load_for_executorch ,
35
- )
32
+ from executorch .extension .pybindings .portable_lib import ExecuTorchModule , _load_for_executorch
36
33
37
34
from ..exporters import TasksManager
38
35
from ..exporters .executorch import main_export
@@ -53,7 +50,7 @@ class ExecuTorchModelForCausalLM(OptimizedModel):
53
50
Attributes:
54
51
auto_model_class (`Type`):
55
52
Associated Transformers class, `AutoModelForCausalLM`.
56
- et_model (`ExecuTorchModule`):
53
+ model (`ExecuTorchModule`):
57
54
The loaded ExecuTorch model.
58
55
use_kv_cache (`bool`):
59
56
Whether key-value caching is enabled. For performance reasons, the exported model is
@@ -74,29 +71,25 @@ class ExecuTorchModelForCausalLM(OptimizedModel):
74
71
75
72
auto_model_class = AutoModelForCausalLM
76
73
77
def __init__(self, model: "ExecuTorchModule", config: "PretrainedConfig"):
    """
    Wrap a loaded ExecuTorch module and copy its exported metadata onto the instance.

    Args:
        model (`ExecuTorchModule`):
            The loaded ExecuTorch module; forwarded to the parent constructor.
        config (`PretrainedConfig`):
            Configuration of the model; forwarded to the parent constructor.

    Each metadata method listed below is optional. When the module exposes it, the method
    is run once and the first element of its result is stored on the matching attribute;
    otherwise the attribute is left unset.
    """
    super().__init__(model, config)
    # `self.model` is not assigned here: this code relies on the parent constructor
    # having exposed `model` under that name.
    metadata = self.model.method_names()
    logging.info(f"Load all static methods: {metadata}")

    # (exported method name, attribute that receives its value), in assignment order.
    metadata_attributes = (
        ("use_kv_cache", "use_kv_cache"),
        ("get_max_seq_len", "max_cache_size"),
        ("get_max_batch_size", "max_batch_size"),
        ("get_dtype", "dtype"),
        ("get_bos_id", "bos_token_id"),
        ("get_eos_id", "eos_token_id"),
        ("get_vocab_size", "vocab_size"),
    )
    for method_name, attribute_name in metadata_attributes:
        if method_name in metadata:
            setattr(self, attribute_name, self.model.run_method(method_name)[0])
100
93
101
94
def forward (
102
95
self ,
@@ -113,28 +106,33 @@ def forward(
113
106
Returns:
114
107
torch.Tensor: Logits output from the model.
115
108
"""
116
- return self .et_model .forward ((input_ids , cache_position ))[0 ]
117
-
109
+ return self .model .forward ((input_ids , cache_position ))[0 ]
110
+
118
111
@classmethod
119
112
def _from_pretrained (
120
113
cls ,
121
- model_dir_path : Union [str , Path ],
114
+ model_id : Union [str , Path ],
122
115
config : PretrainedConfig ,
116
+ token : Optional [Union [bool , str ]] = None ,
123
117
subfolder : str = "" ,
124
118
revision : Optional [str ] = None ,
125
119
cache_dir : str = HUGGINGFACE_HUB_CACHE ,
126
120
force_download : bool = False ,
127
121
local_files_only : bool = False ,
128
- token : Optional [Union [bool , str ]] = None ,
122
+ ** kwargs ,
123
+ # model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
129
124
) -> "ExecuTorchModelForCausalLM" :
130
125
"""
131
- Load a pre-trained ExecuTorch model from a local directory.
126
+ Load a pre-trained ExecuTorch model from a local directory or hosted on the HF hub .
132
127
133
128
Args:
134
- model_dir_path (`Union[str, Path]`):
129
+ model_id (`Union[str, Path]`):
135
130
Path to the directory containing the ExecuTorch model file (`model.pte`).
136
131
config (`PretrainedConfig`, *optional*):
137
132
Configuration of the pre-trained model.
133
+ token (`Optional[Union[bool,str]]`, defaults to `None`):
134
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
135
+ when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
138
136
subfolder (`str`, defaults to `""`):
139
137
In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
140
138
specify the folder name here.
@@ -147,43 +145,71 @@ def _from_pretrained(
147
145
cached versions if they exist.
148
146
local_files_only (`Optional[bool]`, defaults to `False`):
149
147
Whether or not to only look at local files (i.e., do not try to download the model).
150
- token (`Optional[Union[bool,str]]`, defaults to `None`):
151
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
152
- when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
153
148
154
149
Returns:
155
150
`ExecuTorchModelForCausalLM`: The initialized ExecuTorch model.
156
151
157
152
"""
158
- full_path = os .path .join (f"{ model_dir_path } " , "model.pte" )
159
- model = _load_for_executorch (full_path )
160
- logging .info (f"Loaded model from { full_path } " )
161
- logging .debug (f"{ model .method_meta ('forward' )} " )
162
- return cls (
163
- model = model ,
164
- config = config ,
153
+ model_path = Path (model_id )
154
+ default_file_name = "model.pte"
155
+
156
+ model_cache_path = cls ._cached_file (
157
+ model_path = model_path ,
158
+ token = token ,
159
+ revision = revision ,
160
+ force_download = force_download ,
161
+ cache_dir = cache_dir ,
162
+ file_name = default_file_name ,
163
+ subfolder = subfolder ,
164
+ local_files_only = local_files_only ,
165
165
)
166
+ model = _load_for_executorch (model_cache_path )
167
+ logging .info (f"Loaded model from { model_cache_path } " )
166
168
167
- def _save_pretrained (self , save_directory ):
168
- """
169
- Saves a model weights into a directory, so that it can be re-loaded using the
170
- [`from_pretrained`] class method.
171
- """
172
- raise NotImplementedError
169
+ return cls (model , config = config )
170
+
171
+ @staticmethod
172
+ def _cached_file (
173
+ model_path : Union [Path , str ],
174
+ token : Optional [Union [bool , str ]] = None ,
175
+ revision : Optional [str ] = None ,
176
+ force_download : bool = False ,
177
+ cache_dir : Optional [str ] = None ,
178
+ file_name : Optional [str ] = None ,
179
+ subfolder : str = "" ,
180
+ local_files_only : bool = False ,
181
+ ):
182
+ model_path = Path (model_path )
183
+ # locates a file in a local folder and repo, downloads and cache it if necessary.
184
+ if model_path .is_dir ():
185
+ model_cache_path = os .path .join (model_path , subfolder , file_name )
186
+ else :
187
+ model_cache_path = hf_hub_download (
188
+ repo_id = model_path .as_posix (),
189
+ filename = file_name ,
190
+ subfolder = subfolder ,
191
+ token = token ,
192
+ revision = revision ,
193
+ cache_dir = cache_dir ,
194
+ force_download = force_download ,
195
+ local_files_only = local_files_only ,
196
+ )
197
+
198
+ return model_cache_path
173
199
174
200
@classmethod
175
201
def _export (
176
202
cls ,
177
203
model_id : str ,
178
204
recipe : str ,
179
205
config : PretrainedConfig ,
206
+ token : Optional [Union [bool , str ]] = None ,
207
+ revision : Optional [str ] = None ,
180
208
cache_dir : str = HUGGINGFACE_HUB_CACHE ,
181
209
trust_remote_code : bool = False ,
182
210
subfolder : str = "" ,
183
- revision : Optional [str ] = None ,
184
211
force_download : bool = False ,
185
212
local_files_only : bool = False ,
186
- token : Optional [Union [bool , str ]] = None ,
187
213
** kwargs ,
188
214
):
189
215
"""
@@ -228,6 +254,7 @@ def _export(
228
254
229
255
save_dir = TemporaryDirectory ()
230
256
save_dir_path = Path (save_dir .name )
257
+
231
258
# Export to ExecuTorch and save the pte file to the temporary directory
232
259
main_export (
233
260
model_name_or_path = model_id ,
@@ -243,17 +270,14 @@ def _export(
243
270
trust_remote_code = trust_remote_code ,
244
271
** kwargs ,
245
272
)
273
+ return cls ._from_pretrained (model_id = save_dir_path , config = config )
246
274
247
- return cls ._from_pretrained (
248
- model_dir_path = save_dir_path ,
249
- config = config ,
250
- subfolder = subfolder ,
251
- revision = revision ,
252
- cache_dir = cache_dir ,
253
- token = token ,
254
- local_files_only = local_files_only ,
255
- force_download = force_download ,
256
- )
275
+ def _save_pretrained (self , save_directory ):
276
+ """
277
+ Saves a model weights into a directory, so that it can be re-loaded using the
278
+ [`from_pretrained`] class method.
279
+ """
280
+ raise NotImplementedError
257
281
258
282
def generate (
259
283
self ,
@@ -367,4 +391,4 @@ def text_generation(
367
391
def _from_transformers (cls , * args , ** kwargs ):
368
392
# TODO : add warning when from_pretrained_method is set to cls._export instead of cls._from_transformers when export=True
369
393
# logger.warning("The method `_from_transformers` is deprecated, please use `_export` instead")
370
- return cls ._export (* args , ** kwargs )
394
+ return cls ._export (* args , ** kwargs )
0 commit comments