Skip to content

Commit 6f4685b

Browse files
committed
add test
1 parent f33d6f5 commit 6f4685b

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

tests/models/test_modeling.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
import os
1717
import tempfile
1818
import unittest
19+
from pathlib import Path
20+
from tempfile import TemporaryDirectory
1921

2022
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
23+
from huggingface_hub import HfApi
2124

2225
from optimum.executorch import ExecuTorchModelForCausalLM
2326
from optimum.exporters.executorch import main_export
27+
from optimum.utils.file_utils import _find_files_matching_pattern
2428

2529

2630
class ExecuTorchModelIntegrationTest(unittest.TestCase):
@@ -30,7 +34,7 @@ def __init__(self, *args, **kwargs):
3034
def test_load_model_from_hub(self):
3135
model_id = "optimum-internal-testing/tiny-random-llama"
3236

33-
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, export=True, recipe="xnnpack")
37+
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
3438
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
3539
self.assertIsInstance(model.model, ExecuTorchModule)
3640

@@ -58,6 +62,26 @@ def test_load_model_from_local_path(self):
5862
self.assertTrue(os.path.exists(f"{tempdir}/model.pte"))
5963

6064
# Load the exported model from a local dir
61-
model = ExecuTorchModelForCausalLM.from_pretrained(tempdir, export=False)
65+
model = ExecuTorchModelForCausalLM.from_pretrained(tempdir)
6266
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
6367
self.assertIsInstance(model.model, ExecuTorchModule)
68+
69+
def test_find_files_matching_pattern(self):
70+
model_id = "optimum-internal-testing/tiny-random-llama"
71+
pattern = r"(.*).pte$"
72+
73+
# hub model
74+
for revision in ("main", "executorch"):
75+
pte_files = _find_files_matching_pattern(model_id, pattern=pattern, revision=revision)
76+
self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
77+
78+
# local model
79+
api = HfApi()
80+
with TemporaryDirectory() as tmpdirname:
81+
for revision in ("main", "executorch"):
82+
local_dir = Path(tmpdirname) / revision
83+
api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
84+
pte_files = _find_files_matching_pattern(
85+
local_dir, pattern=pattern, revision=revision, subfolder=revision
86+
)
87+
self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)

0 commit comments

Comments
 (0)