16
16
import os
17
17
import tempfile
18
18
import unittest
19
+ from pathlib import Path
20
+ from tempfile import TemporaryDirectory
19
21
20
22
from executorch .extension .pybindings .portable_lib import ExecuTorchModule
23
+ from huggingface_hub import HfApi
21
24
22
25
from optimum .executorch import ExecuTorchModelForCausalLM
23
26
from optimum .exporters .executorch import main_export
27
+ from optimum .utils .file_utils import _find_files_matching_pattern
24
28
25
29
26
30
class ExecuTorchModelIntegrationTest (unittest .TestCase ):
@@ -30,7 +34,7 @@ def __init__(self, *args, **kwargs):
30
34
def test_load_model_from_hub (self ):
31
35
model_id = "optimum-internal-testing/tiny-random-llama"
32
36
33
- model = ExecuTorchModelForCausalLM .from_pretrained (model_id , export = True , recipe = "xnnpack" )
37
+ model = ExecuTorchModelForCausalLM .from_pretrained (model_id , recipe = "xnnpack" )
34
38
self .assertIsInstance (model , ExecuTorchModelForCausalLM )
35
39
self .assertIsInstance (model .model , ExecuTorchModule )
36
40
@@ -58,6 +62,26 @@ def test_load_model_from_local_path(self):
58
62
self .assertTrue (os .path .exists (f"{ tempdir } /model.pte" ))
59
63
60
64
# Load the exported model from a local dir
61
- model = ExecuTorchModelForCausalLM .from_pretrained (tempdir , export = False )
65
+ model = ExecuTorchModelForCausalLM .from_pretrained (tempdir )
62
66
self .assertIsInstance (model , ExecuTorchModelForCausalLM )
63
67
self .assertIsInstance (model .model , ExecuTorchModule )
68
+
69
+ def test_find_files_matching_pattern (self ):
70
+ model_id = "optimum-internal-testing/tiny-random-llama"
71
+ pattern = r"(.*).pte$"
72
+
73
+ # hub model
74
+ for revision in ("main" , "executorch" ):
75
+ pte_files = _find_files_matching_pattern (model_id , pattern = pattern , revision = revision )
76
+ self .assertTrue (len (pte_files ) == 0 if revision == "main" else len (pte_files ) > 0 )
77
+
78
+ # local model
79
+ api = HfApi ()
80
+ with TemporaryDirectory () as tmpdirname :
81
+ for revision in ("main" , "executorch" ):
82
+ local_dir = Path (tmpdirname ) / revision
83
+ api .snapshot_download (repo_id = model_id , local_dir = local_dir , revision = revision )
84
+ pte_files = _find_files_matching_pattern (
85
+ local_dir , pattern = pattern , revision = revision , subfolder = revision
86
+ )
87
+ self .assertTrue (len (pte_files ) == 0 if revision == "main" else len (pte_files ) > 0 )
0 commit comments