Commit 2dbca5e

Use Silero VAD in Batched Mode (#936)
Replace Pyannote VAD with Silero to reduce code duplication and requirements
1 parent 574e256 · commit 2dbca5e

12 files changed: +277 -508 lines
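
For context on what the Silero swap affects, here is a minimal sketch (not part of this commit's diff) of faster-whisper's documented VAD options; the model name and parameter value are illustrative:

```python
from faster_whisper import WhisperModel

model = WhisperModel("large-v3", device="cuda")

# vad_filter runs the bundled Silero VAD and skips non-speech audio;
# the parameter value here is illustrative, not this commit's default.
segments, info = model.transcribe(
    "audio.mp3",
    vad_filter=True,
    vad_parameters=dict(min_silence_duration_ms=500),
)
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```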

MANIFEST.in

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-include faster_whisper/assets/silero_vad.onnx
+include faster_whisper/assets/silero_encoder_v5.onnx
+include faster_whisper/assets/silero_decoder_v5.onnx
 include requirements.txt
 include requirements.conversion.txt
-include faster_whisper/assets/pyannote_vad_model.bin
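
A small sketch of how the packaged assets can be resolved at runtime (an assumption for illustration; the diff itself only changes which files MANIFEST.in ships inside the package):

```python
import os

import faster_whisper

# Illustrative only: resolve the bundled ONNX assets named in MANIFEST.in
# relative to the installed package directory.
assets_dir = os.path.join(os.path.dirname(faster_whisper.__file__), "assets")
encoder_path = os.path.join(assets_dir, "silero_encoder_v5.onnx")
decoder_path = os.path.join(assets_dir, "silero_decoder_v5.onnx")
print(encoder_path, decoder_path)
```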

README.md

Lines changed: 0 additions & 3 deletions

@@ -178,9 +178,6 @@ language_info = model.detect_language_multi_segment("audio.mp3")
 
 ### Batched faster-whisper
 
-
-The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-2 Clause license and integrates its VAD model to this library. We modify this implementation and also replaced the feature extraction with a faster torch-based implementation. Batched version improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference.
-
 The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper.
 
 ```python
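
The Python snippet that follows in the README is cut off by this hunk. A minimal sketch of batched usage, mirroring the API exercised by the benchmark script added in this commit (model name and batch size are illustrative):

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

# Illustrative values; the README's actual snippet is not shown in this hunk.
model = WhisperModel("large-v3", device="cuda")
pipeline = BatchedInferencePipeline(model, device="cuda")

# VAD-segmented chunks are decoded together in batches.
segments, info = pipeline.transcribe("audio.mp3", batch_size=8)
for segment in segments:
    print(segment.text)
```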

benchmark/evaluate_yt_commons.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+import argparse
+import json
+import os
+
+from io import BytesIO
+
+from datasets import load_dataset
+from evaluate import load
+from pytubefix import YouTube
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
+
+from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
+
+
+def url_to_audio(row):
+    buffer = BytesIO()
+    yt = YouTube(row["link"])
+    video = (
+        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+        .order_by("bitrate")
+        .desc()
+        .first()
+    )
+    video.stream_to_buffer(buffer)
+    buffer.seek(0)
+    row["audio"] = decode_audio(buffer)
+    return row
+
+
+parser = argparse.ArgumentParser(description="WER benchmark")
+parser.add_argument(
+    "--audio_numb",
+    type=int,
+    default=None,
+    help="Specify the number of validation audio files in the dataset."
+    " Set to None to retrieve all audio files.",
+)
+args = parser.parse_args()
+
+# define the evaluation metric
+wer_metric = load("wer")
+
+with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
+    normalizer = EnglishTextNormalizer(json.load(f))
+
+dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
+    url_to_audio
+)
+dataset = iter(
+    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
+)
+
+model = WhisperModel("large-v3", device="cuda")
+pipeline = BatchedInferencePipeline(model, device="cuda")
+
+
+all_transcriptions = []
+all_references = []
+# iterate over the dataset and run inference
+for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+    result, info = pipeline.transcribe(
+        row["audio"][0],
+        batch_size=8,
+        word_timestamps=False,
+        without_timestamps=True,
+    )
+
+    all_transcriptions.append("".join(segment.text for segment in result))
+    all_references.append(row["text"][0])
+    if args.audio_numb and i == (args.audio_numb - 1):
+        break
+
+# normalize predictions and references
+all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
+all_references = [normalizer(reference) for reference in all_references]
+
+# compute the WER metric
+wer = 100 * wer_metric.compute(
+    predictions=all_transcriptions, references=all_references
+)
+print("WER: %.3f" % wer)
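
With the benchmark requirements installed, the script can be run from the repository root; for example, to stop after the first 50 videos (the flag comes from the argparse definition above):

    python benchmark/evaluate_yt_commons.py --audio_numb 50

Omitting `--audio_numb` streams the entire evaluation split.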

benchmark/requirements.benchmark.txt

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@ evaluate
 datasets
 memory_profiler
 py3nvml
+pytubefix
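
pytubefix is the newly added dependency; the benchmark's url_to_audio uses it to fetch YouTube audio. The full set installs with:

    pip install -r benchmark/requirements.benchmark.txt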
faster_whisper/assets/pyannote_vad_model.bin

-16.9 MB
Binary file not shown.

faster_whisper/assets/silero_decoder_v5.onnx

+520 KB
Binary file not shown.

faster_whisper/assets/silero_encoder_v5.onnx

+697 KB
Binary file not shown.

faster_whisper/assets/silero_vad.onnx

-2.21 MB
Binary file not shown.
