Commit 4b37326

Fix for Qwen3 embedding (#85)

Authored by Guang Yang (guangy10)
Co-authored-by: Guang Yang <[email protected]>
1 parent 4c3b18f commit 4b37326

File tree

4 files changed: +124 -4 lines changed

4 files changed

+124
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ We currently support a wide range of popular transformer models, including encod
 - [Gemma3](https://huggingface.co/google/gemma-3-1b-it): `Gemma-3-1b` and its variants *(requires `transformers >= 4.52.0`)*
 - [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B): `Llama-3.2-1B` and its variants
 - [Qwen2](https://huggingface.co/Qwen/Qwen2.5-0.5B): `Qwen2.5-0.5B` and its variants
-- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B): `Qwen3-0.6B` and its variants
+- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B): `Qwen3-0.6B`, `Qwen3-Embedding-0.6B` and other variants
 - [Olmo](https://huggingface.co/allenai/OLMo-1B-hf): `OLMo-1B-hf` and its variants
 - [Phi4](https://huggingface.co/microsoft/Phi-4-mini-instruct): `Phi-4-mini-instruct` and its variants
 - [Smollm](https://huggingface.co/HuggingFaceTB/SmolLM2-135M): 🤗 `SmolLM2-135M` and its variants
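
A minimal usage sketch for the newly listed checkpoint, mirroring the test added in this commit (the model ID, the `xnnpack` recipe, and the `text_generation` arguments are all taken from that test):

```python
from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "Qwen/Qwen3-Embedding-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Export and load through the XNNPACK recipe, as the new test does.
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
print(model.text_generation(tokenizer=tokenizer, prompt="Explain gravity", max_seq_len=64))
```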

optimum/executorch/modeling.py

Lines changed: 3 additions & 2 deletions
@@ -41,6 +41,7 @@
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
+from ..exporters.executorch.utils import verify_eos_tokens_in_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -736,9 +737,9 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id not in self.eos_token_ids:
+        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
             raise ValueError(
-                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must match the model's eos_token_ids={self.eos_token_ids}."
+                f"The tokenizer's eos_token_id does not match the model's eos_token_ids={self.eos_token_ids}."
             )

         # Reset stats for a new generation
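
In toy form, the behavioral difference this change buys (all token IDs below are illustrative stand-ins, not values read from the real Qwen3-Embedding vocabulary):

```python
# Illustrative IDs only: suppose the exported PTE recorded an <|im_end|>-style
# EOS ID while the tokenizer reports a different eos_token_id, e.g. <|endoftext|>.
model_eos_ids = [151645]   # EOS IDs recorded in the PTE (the source of truth)
tokenizer_eos_id = 151643  # what tokenizer.eos_token_id reports

# Old check: raised whenever tokenizer.eos_token_id was set but absent from the list.
old_raises = tokenizer_eos_id is not None and tokenizer_eos_id not in model_eos_ids  # True

# New check: passes as long as any model EOS ID appears among the tokenizer's
# candidate end-of-sequence tokens (eos/pad/special tokens plus "end"-style added tokens).
candidate_eos_ids = {151643, 151645}
new_ok = any(eos_id in candidate_eos_ids for eos_id in model_eos_ids)  # True, no error
```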

optimum/exporters/executorch/utils.py

Lines changed: 47 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import List, Optional, Set

 import torch
 from transformers import GenerationConfig, PretrainedConfig
@@ -65,3 +65,49 @@ def save_config_to_constant_methods(

     # Combine with any additional kwargs and filter out None values
     return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}
+
+
+def verify_eos_tokens_in_tokenizer(model_eos_ids: List[int], tokenizer) -> bool:
+    """
+    Verifies that the model's EOS token IDs are present in the tokenizer's
+    set of potential end-of-sequence tokens.
+
+    Args:
+        model_eos_ids: A list of EOS token IDs recorded in the PTE file (the source of truth).
+        tokenizer: The Hugging Face tokenizer instance to check.
+
+    Returns:
+        True if at least one model EOS ID is found among the tokenizer's potential
+        EOS tokens, False otherwise.
+    """
+    if not model_eos_ids:
+        print("Warning: model_eos_ids list is empty. No verification can be performed.")
+        return True
+
+    candidate_eos_ids: Set[int] = set()
+
+    # 1. Check primary eos_token and pad_token attributes
+    if tokenizer.eos_token_id is not None:
+        candidate_eos_ids.add(tokenizer.eos_token_id)
+    if tokenizer.pad_token_id is not None:
+        candidate_eos_ids.add(tokenizer.pad_token_id)
+
+    # 2. Check all tokens listed in the special_tokens_map
+    for token_string in tokenizer.special_tokens_map.values():
+        if token_string:
+            # Use convert_tokens_to_ids for robustness
+            token_id = tokenizer.convert_tokens_to_ids(token_string)
+            if isinstance(token_id, int):
+                candidate_eos_ids.add(token_id)
+
+    # 3. Check added tokens for "end-of-X" patterns
+    for token_id, added_token in tokenizer.added_tokens_decoder.items():
+        token_str = added_token.content.lower()
+        # Heuristic to find tokens that signify an end
+        if "end" in token_str or token_str.startswith("</"):
+            candidate_eos_ids.add(token_id)
+
+    # The check: is any "true" ID present in the candidate set?
+    is_valid = any(model_id in candidate_eos_ids for model_id in model_eos_ids)
+
+    return is_valid
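
A hedged usage sketch of the new helper; the import path follows the call site added in modeling.py above, and the IDs passed in are illustrative stand-ins for what a PTE file would record:

```python
from transformers import AutoTokenizer

from optimum.exporters.executorch.utils import verify_eos_tokens_in_tokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
model_eos_ids = [151643, 151645]  # illustrative: EOS IDs a PTE file might record
# True as long as any recorded EOS ID shows up among the tokenizer's
# eos/pad/special tokens or its "end"-style added tokens.
print(verify_eos_tokens_in_tokenizer(model_eos_ids, tokenizer))
```

Including `pad_token_id` among the candidates covers tokenizers that alias padding to EOS, and the `added_tokens_decoder` heuristic can catch chat-style terminators such as `<|im_end|>` that may not appear in `special_tokens_map`.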
New test file

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import os
+import unittest
+
+import pytest
+import torchao
+import transformers
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
+from transformers import AutoTokenizer
+from transformers.testing_utils import slow
+
+from optimum.executorch import ExecuTorchModelForCausalLM
+
+from ..utils import check_causal_lm_output_quality
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(transformers.__version__) < parse("4.52.0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.52.0 and torchao >= 0.11.0",
+    )
+    def test_qwen3_embedding_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
+        model_id = "Qwen/Qwen3-Embedding-0.6B"
+        prompt = "Explain gravity"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            use_custom_kv_cache=True,
+            **{"qlinear": True, "qembedding": True},
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
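
Since the test is gated by both `@slow` and `pytest.mark.run_slow`, running it locally should look something like `RUN_SLOW=1 pytest -m run_slow -k test_qwen3_embedding_text_generation` (the `RUN_SLOW` environment variable is what `transformers.testing_utils.slow` checks; the exact invocation is an assumption, not taken from this commit).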

0 commit comments
