
Commit eb83902

Jiltseb, hargunmujral, and MahmoudAshraf97 authored

New PR for Faster Whisper: Batching Support, Speed Boosts, and Quality Enhancements (#856)

Batching Support, Speed Boosts, and Quality Enhancements

---------

Co-authored-by: Hargun Mujral <[email protected]>
Co-authored-by: MahmoudAshraf97 <[email protected]>
1 parent fbcf58b commit eb83902

File tree

13 files changed (+1693, −419 lines)


MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 include faster_whisper/assets/silero_vad.onnx
 include requirements.txt
 include requirements.conversion.txt
+include faster_whisper/assets/pyannote_vad_model.bin
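With the VAD model now shipped inside the package (and MANIFEST.in ensuring it lands in source distributions), the file can be resolved relative to the installed module at runtime. A minimal sketch of that pattern, assuming nothing about faster-whisper's internal loading code (`asset_path` is a hypothetical helper, not the library's API):

```python
# Hypothetical sketch: locate a bundled asset relative to the installed
# package so the lookup works from any working directory.
import os

import faster_whisper


def asset_path(name: str) -> str:
    # Resolves to .../site-packages/faster_whisper/assets/<name>
    return os.path.join(os.path.dirname(faster_whisper.__file__), "assets", name)


vad_model_path = asset_path("pyannote_vad_model.bin")
```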

README.md

Lines changed: 29 additions & 1 deletion
@@ -69,7 +69,6 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
 
 * Python 3.8 or greater
 
-Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.
 
 ### GPU
 
@@ -166,6 +165,35 @@ for segment in segments:
 segments, _ = model.transcribe("audio.mp3")
 segments = list(segments)  # The transcription will actually run here.
 ```
+
+### Multi-segment language detection
+
+To use the model directly for improved language detection, the following code snippet can be used:
+
+```python
+from faster_whisper import WhisperModel
+model = WhisperModel("medium", device="cuda", compute_type="float16")
+language_info = model.detect_language_multi_segment("audio.mp3")
+```
+
+### Batched faster-whisper
+
+The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX), licensed under the BSD-2-Clause license, and integrates its VAD model into this library. We modified this implementation and replaced the feature extraction with a faster torch-based one. The batched version improves speed by up to 10-12x compared to the OpenAI implementation and 3-4x compared to the sequential faster-whisper version. It works by transcribing semantically meaningful audio chunks as batches, leading to faster inference.
+
+The following code snippet illustrates how to run batched inference on an example audio file. Please also refer to the test scripts of batched faster-whisper.
+
+```python
+from faster_whisper import WhisperModel, BatchedInferencePipeline
+
+model = WhisperModel("medium", device="cuda", compute_type="float16")
+batched_model = BatchedInferencePipeline(model=model)
+segments, info = batched_model.transcribe("audio.mp3", batch_size=16)
+
+for segment in segments:
+    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
+```
+
 ### Faster Distil-Whisper
 
 The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)
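For readers who want to sanity-check the claimed speedup on their own hardware, here is a rough single-run timing sketch built only from the API shown above; the `batch_size=16` value and the timing approach are illustrative, not a rigorous benchmark:

```python
import time

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("medium", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)

start = time.perf_counter()
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)
# As in the sequential API, segments may be evaluated lazily, so consume
# them inside the timed region.
text = " ".join(segment.text for segment in segments)
print(f"Batched transcription took {time.perf_counter() - start:.1f}s")
```

Timing `model.transcribe("audio.mp3")` the same way gives the sequential baseline for comparison.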

benchmark/wer_benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 import argparse
 import json
+import os
 
 from datasets import load_dataset
 from evaluate import load
@@ -26,7 +27,9 @@
 
 # define the evaluation metric
 wer_metric = load("wer")
-normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))
+
+with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
+    normalizer = EnglishTextNormalizer(json.load(f))
 
 
 def inference(batch):
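The change above is an instance of a general pattern: resolve data files relative to the script's own location (`__file__`) rather than the current working directory, so the benchmark runs correctly no matter where it is launched from. A generic sketch of the pattern (the helper name is illustrative):

```python
import json
import os


def load_json_beside_script(filename: str):
    # os.path.dirname(__file__) is the directory containing this script,
    # so the lookup is independent of the caller's working directory.
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path, "r") as f:
        return json.load(f)
```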

faster_whisper/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,12 +1,13 @@
 from faster_whisper.audio import decode_audio
-from faster_whisper.transcribe import WhisperModel
+from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel
 from faster_whisper.utils import available_models, download_model, format_timestamp
 from faster_whisper.version import __version__
 
 __all__ = [
     "available_models",
     "decode_audio",
     "WhisperModel",
+    "BatchedInferencePipeline",
     "download_model",
     "format_timestamp",
     "__version__",
faster_whisper/assets/pyannote_vad_model.bin

16.9 MB
Binary file not shown.

faster_whisper/audio.py

Lines changed: 22 additions & 83 deletions
@@ -1,19 +1,7 @@
-"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
-
-The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
-system dependencies. FFmpeg does not need to be installed on the system.
-
-However, the API is quite low-level so we need to manipulate audio frames directly.
-"""
-
-import gc
-import io
-import itertools
-
 from typing import BinaryIO, Union
 
-import av
-import numpy as np
+import torch
+import torchaudio
 
 
 def decode_audio(
@@ -29,91 +17,42 @@ def decode_audio(
       split_stereo: Return separate left and right channels.
 
     Returns:
-      A float32 Numpy array.
+      A float32 Torch Tensor.
 
       If `split_stereo` is enabled, the function returns a 2-tuple with the
       separated left and right channels.
     """
-    resampler = av.audio.resampler.AudioResampler(
-        format="s16",
-        layout="mono" if not split_stereo else "stereo",
-        rate=sampling_rate,
-    )
-
-    raw_buffer = io.BytesIO()
-    dtype = None
 
-    with av.open(input_file, mode="r", metadata_errors="ignore") as container:
-        frames = container.decode(audio=0)
-        frames = _ignore_invalid_frames(frames)
-        frames = _group_frames(frames, 500000)
-        frames = _resample_frames(frames, resampler)
-
-        for frame in frames:
-            array = frame.to_ndarray()
-            dtype = array.dtype
-            raw_buffer.write(array)
-
-    # It appears that some objects related to the resampler are not freed
-    # unless the garbage collector is manually run.
-    del resampler
-    gc.collect()
-
-    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
-
-    # Convert s16 back to f32.
-    audio = audio.astype(np.float32) / 32768.0
+    waveform, audio_sf = torchaudio.load(input_file)  # waveform: channels x T
 
+    if audio_sf != sampling_rate:
+        waveform = torchaudio.functional.resample(
+            waveform, orig_freq=audio_sf, new_freq=sampling_rate
+        )
     if split_stereo:
-        left_channel = audio[0::2]
-        right_channel = audio[1::2]
-        return left_channel, right_channel
-
-    return audio
-
-
-def _ignore_invalid_frames(frames):
-    iterator = iter(frames)
-
-    while True:
-        try:
-            yield next(iterator)
-        except StopIteration:
-            break
-        except av.error.InvalidDataError:
-            continue
-
-
-def _group_frames(frames, num_samples=None):
-    fifo = av.audio.fifo.AudioFifo()
-
-    for frame in frames:
-        frame.pts = None  # Ignore timestamp check.
-        fifo.write(frame)
-
-        if num_samples is not None and fifo.samples >= num_samples:
-            yield fifo.read()
-
-    if fifo.samples > 0:
-        yield fifo.read()
-
+        return waveform[0], waveform[1]
 
-def _resample_frames(frames, resampler):
-    # Add None to flush the resampler.
-    for frame in itertools.chain(frames, [None]):
-        yield from resampler.resample(frame)
+    return waveform.mean(0)
 
 
 def pad_or_trim(array, length: int, *, axis: int = -1):
     """
     Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
     """
+    axis = axis % array.ndim
     if array.shape[axis] > length:
-        array = array.take(indices=range(length), axis=axis)
+        idx = [slice(None)] * array.ndim
+        idx[axis] = slice(length)
+        return array[tuple(idx)]
 
     if array.shape[axis] < length:
-        pad_widths = [(0, 0)] * array.ndim
-        pad_widths[axis] = (0, length - array.shape[axis])
-        array = np.pad(array, pad_widths)
+        pad_widths = [0] * (2 * array.ndim)
+        pad_widths[2 * axis] = length - array.shape[axis]
+        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
 
     return array
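A short usage sketch of the rewritten helpers, assuming a local `audio.mp3`; note that `decode_audio` now returns a `torch.Tensor` rather than a NumPy array:

```python
from faster_whisper.audio import decode_audio, pad_or_trim

# Decode and resample to 16 kHz; mono by default (channels are averaged).
audio = decode_audio("audio.mp3", sampling_rate=16000)  # 1-D float32 tensor

# Pad or trim along the last axis to exactly 30 s of 16 kHz audio.
audio_30s = pad_or_trim(audio, 30 * 16000)
print(audio_30s.shape)  # torch.Size([480000])

# For a stereo file, split_stereo=True returns the two channels separately.
left, right = decode_audio("audio.mp3", split_stereo=True)
```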
