
Commit 1195359

Filter out non_speech_tokens in suppressed tokens (#898)
* Filter out non_speech_tokens in suppressed tokens
1 parent c22db51 · commit 1195359

File tree: 3 files changed, +159 -11 lines changed

faster_whisper/tokenizer.py

Lines changed: 36 additions & 0 deletions
@@ -105,6 +105,42 @@ def decode_with_timestamps(self, tokens: List[int]) -> str:
             [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         )
 
+    @cached_property
+    def non_speech_tokens(self) -> Tuple[int]:
+        """
+        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+        - ♪♪♪
+        - ( SPEAKING FOREIGN LANGUAGE )
+        - [DAVID] Hey there,
+
+        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+        """
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
+
+        # symbols that may be a single token or multiple tokens depending on the tokenizer.
+        # In case they're multiple tokens, suppress the first token, which is safe because:
+        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {self.encode(" -")[0], self.encode(" '")[0]}
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.encode(symbol),
+                self.encode(" " + symbol),
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
     def split_to_word_tokens(
         self, tokens: List[int]
     ) -> Tuple[List[str], List[List[int]]]:
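
The easiest way to see what the new property suppresses is to decode a few of the returned IDs. A minimal sketch, assuming faster-whisper is installed and can download the "tiny.en" model; the Tokenizer construction mirrors the tests added in this commit, and the inner tokenizer.tokenizer.decode call goes through the wrapped Hugging Face tokenizer seen in decode_with_timestamps above.

# Sketch: inspect which text fragments non_speech_tokens suppresses.
# Assumes faster-whisper is installed and "tiny.en" can be downloaded.
from faster_whisper import WhisperModel
from faster_whisper.tokenizer import Tokenizer

model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)

ids = tokenizer.non_speech_tokens
print(len(ids))  # number of non-speech token IDs for this vocabulary
for token_id in ids[:10]:
    # decode via the wrapped Hugging Face tokenizer, as in decode_with_timestamps
    print(token_id, repr(tokenizer.tokenizer.decode([token_id])))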

faster_whisper/transcribe.py

Lines changed: 15 additions & 11 deletions
@@ -277,7 +277,7 @@ def transcribe(
           prefix: Optional text to provide as a prefix for the first window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
-            of symbols as defined in the model config.json file.
+            of symbols as defined in `tokenizer.non_speech_tokens()`
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
           word_timestamps: Extract word-level timestamps using the cross-attention pattern

@@ -462,7 +462,11 @@ def transcribe(
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
+            suppress_tokens=(
+                get_suppressed_tokens(tokenizer, suppress_tokens)
+                if suppress_tokens
+                else suppress_tokens
+            ),
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
             word_timestamps=word_timestamps,

@@ -488,7 +492,6 @@ def transcribe(
             vad_options=vad_parameters,
             all_language_probs=all_language_probs,
         )
-
        return segments, info
 
    def generate_segments(

@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
 
 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Optional[List[int]],
+    suppress_tokens: Tuple[int],
 ) -> Optional[List[int]]:
-    if not suppress_tokens or -1 in suppress_tokens:
-        return suppress_tokens
-
-    suppress_tokens = list(suppress_tokens)
+    if -1 in suppress_tokens:
+        suppress_tokens = [t for t in suppress_tokens if t >= 0]
+        suppress_tokens.extend(tokenizer.non_speech_tokens)
+    elif suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    else:
+        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
 
-    # Ensure the following special tokens are suppressed when the user does
-    # not use the default set (-1).
     suppress_tokens.extend(
         [
             tokenizer.transcribe,

@@ -1246,7 +1250,7 @@ def get_suppressed_tokens(
         ]
     )
 
-    return sorted(set(suppress_tokens))
+    return tuple(sorted(set(suppress_tokens)))
 
 
 def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
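
Taken together, these changes redefine how suppress_tokens is interpreted: -1 now expands to tokenizer.non_speech_tokens plus the special tokens (instead of the defaults from the model config.json), an explicit list gains only the special tokens, and a falsy value (None or an empty list) bypasses get_suppressed_tokens entirely. A sketch of the first two cases, assuming the same "tiny.en" setup as in the previous snippet:

from faster_whisper import WhisperModel
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens

model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)

# -1 expands to tokenizer.non_speech_tokens plus the special tokens.
default_set = get_suppressed_tokens(tokenizer, [-1])

# An explicit list keeps the given IDs and appends the special tokens.
custom_set = get_suppressed_tokens(tokenizer, [13])
assert 13 in custom_set

# A falsy suppress_tokens (None or []) never reaches this helper:
# transcribe() passes it through unchanged, so nothing extra is suppressed.
print(len(default_set), custom_set)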

tests/test_transcribe.py

Lines changed: 108 additions & 0 deletions
@@ -1,6 +1,8 @@
 import os
 
 from faster_whisper import WhisperModel, decode_audio
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import get_suppressed_tokens
 
 
 def test_supported_languages():

@@ -97,3 +99,109 @@ def test_stereo_diarization(data_dir):
     segments, _ = model.transcribe(right)
     transcription = "".join(segment.text for segment in segments).strip()
     assert transcription == "The horizon seems extremely distant."
+
+
+def test_suppressed_tokens_minus_1():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [-1])
+    assert tokens == (
+        1,
+        2,
+        7,
+        8,
+        9,
+        10,
+        14,
+        25,
+        26,
+        27,
+        28,
+        29,
+        31,
+        58,
+        59,
+        60,
+        61,
+        62,
+        63,
+        90,
+        91,
+        92,
+        93,
+        357,
+        366,
+        438,
+        532,
+        685,
+        705,
+        796,
+        930,
+        1058,
+        1220,
+        1267,
+        1279,
+        1303,
+        1343,
+        1377,
+        1391,
+        1635,
+        1782,
+        1875,
+        2162,
+        2361,
+        2488,
+        3467,
+        4008,
+        4211,
+        4600,
+        4808,
+        5299,
+        5855,
+        6329,
+        7203,
+        9609,
+        9959,
+        10563,
+        10786,
+        11420,
+        11709,
+        11907,
+        13163,
+        13697,
+        13700,
+        14808,
+        15306,
+        16410,
+        16791,
+        17992,
+        19203,
+        19510,
+        20724,
+        22305,
+        22935,
+        27007,
+        30109,
+        30420,
+        33409,
+        34949,
+        40283,
+        40493,
+        40549,
+        47282,
+        49146,
+        50257,
+        50357,
+        50358,
+        50359,
+        50360,
+    )
+
+
+def test_suppressed_tokens_minus_value():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [13])
+    assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
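
The expected tuples are specific to the "tiny.en" vocabulary; other models tokenize the symbol list differently, so the exact IDs would differ. End to end, the user-visible effect looks roughly like the sketch below, where "audio.wav" is a hypothetical placeholder path:

from faster_whisper import WhisperModel

model = WhisperModel("tiny.en")

# suppress_tokens=[-1]: the special tokens plus the new non-speech set
# are suppressed during decoding, filtering annotations such as "♪♪♪".
segments, _ = model.transcribe("audio.wav", suppress_tokens=[-1])

# Empty list: transcribe() skips get_suppressed_tokens entirely,
# so no tokens are suppressed.
segments, _ = model.transcribe("audio.wav", suppress_tokens=[])

print("".join(segment.text for segment in segments))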
