
Commit 3e0ba86

Remove torch dependency, Faster numpy Feature extraction (#1106)
1 parent: 8f01aee

6 files changed: +203 −118 lines
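For callers, the visible change is the return type: decode_audio now returns plain numpy arrays instead of torch tensors, and the feature extractor consumes them directly. A minimal, hypothetical usage sketch (the file name is made up; decode_audio and split_stereo are taken from the diff below):

import numpy as np

from faster_whisper.audio import decode_audio

audio = decode_audio("example.wav")  # now an np.ndarray, not a torch.Tensor
left, right = decode_audio("example.wav", split_stereo=True)
assert isinstance(audio, np.ndarray) and isinstance(left, np.ndarray)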

faster_whisper/audio.py

Lines changed: 6 additions & 15 deletions
@@ -14,7 +14,6 @@
 
 import av
 import numpy as np
-import torch
 
 
 def decode_audio(
@@ -72,9 +71,9 @@ def decode_audio(
     if split_stereo:
         left_channel = audio[0::2]
         right_channel = audio[1::2]
-        return torch.from_numpy(left_channel), torch.from_numpy(right_channel)
+        return left_channel, right_channel
 
-    return torch.from_numpy(audio)
+    return audio
 
 
 def _ignore_invalid_frames(frames):
@@ -113,20 +112,12 @@ def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
     """
     Pad or trim the Mel features array to 3000, as expected by the encoder.
     """
-    axis = axis % array.ndim
     if array.shape[axis] > length:
-        idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
-        return array[idx]
+        array = array.take(indices=range(length), axis=axis)
 
     if array.shape[axis] < length:
-        pad_widths = (
-            [
-                0,
-            ]
-            * array.ndim
-            * 2
-        )
-        pad_widths[2 * axis] = length - array.shape[axis]
-        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
+        pad_widths = [(0, 0)] * array.ndim
+        pad_widths[axis] = (0, length - array.shape[axis])
+        array = np.pad(array, pad_widths)
 
     return array
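The new pad_or_trim is pure numpy: np.take trims along the requested axis (negative axis values work directly, so the old axis % array.ndim normalization is gone), and np.pad extends it. A toy check of both branches, assuming the default length=3000:

import numpy as np

# Trim branch: 3200 frames cut down to 3000 along the last axis.
mel = np.zeros((80, 3200), dtype=np.float32)
trimmed = mel.take(indices=range(3000), axis=-1)
assert trimmed.shape == (80, 3000)

# Pad branch: 2500 frames grown to 3000, padding only the last axis.
short = np.zeros((80, 2500), dtype=np.float32)
pad_widths = [(0, 0)] * short.ndim
pad_widths[-1] = (0, 3000 - short.shape[-1])
padded = np.pad(short, pad_widths)
assert padded.shape == (80, 3000)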

faster_whisper/feature_extractor.py

Lines changed: 161 additions & 45 deletions
@@ -1,21 +1,15 @@
-import torch
+import numpy as np
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
 class FeatureExtractor:
     def __init__(
         self,
-        device: str = "auto",
         feature_size=80,
         sampling_rate=16000,
         hop_length=160,
         chunk_length=30,
         n_fft=400,
     ):
-        if device == "auto":
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        else:
-            self.device = device
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.chunk_length = chunk_length
@@ -25,24 +19,21 @@ def __init__(
         self.sampling_rate = sampling_rate
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
-        )
+        ).astype("float32")
 
     @staticmethod
     def get_mel_filters(sr, n_fft, n_mels=128):
-        """
-        Implementation of librosa.filters.mel in Pytorch
-        """
         # Initialize the weights
         n_mels = int(n_mels)
 
         # Center freqs of each FFT bin
-        fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
+        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
 
         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965
 
-        mels = torch.linspace(min_mel, max_mel, n_mels + 2)
+        mels = np.linspace(min_mel, max_mel, n_mels + 2)
 
         # Fill in the linear scale
         f_min = 0.0
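Two values in this hunk deserve a note: np.fft.rfftfreq(n_fft, 1/sr) yields the one-sided FFT bin centers from DC to Nyquist, and max_mel = 45.245640471924965 is the Slaney-scale mel of 8000 Hz, i.e. Nyquist at the 16 kHz default rate. A quick sanity check, assuming the defaults n_fft=400 and sampling_rate=16000 and the standard Slaney linear spacing f_sp = 200 / 3:

import numpy as np

fftfreqs = np.fft.rfftfreq(n=400, d=1.0 / 16000)
assert fftfreqs.shape == (201,)                        # n_fft // 2 + 1 bins
assert fftfreqs[0] == 0.0 and fftfreqs[-1] == 8000.0   # DC .. Nyquist

# Slaney mel of 8000 Hz: linear below 1000 Hz, log-spaced above.
# 15.0 is (1000 - 0) / (200 / 3); log constants match the next hunk.
max_mel = 15.0 + 27.0 * np.log(8000.0 / 1000.0) / np.log(6.4)
assert np.isclose(max_mel, 45.245640471924965)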
@@ -52,30 +43,159 @@ def get_mel_filters(sr, n_fft, n_mels=128):
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = torch.log(torch.tensor(6.4)) / 27.0  # step size for log region
+        logstep = np.log(6.4) / 27.0  # step size for log region
 
         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
-
-        mel_f = freqs
+        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
 
-        fdiff = torch.diff(mel_f)
-        ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
+        fdiff = np.diff(freqs)
+        ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1)
 
-        lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
-        upper = ramps[2:] / fdiff[1:].unsqueeze(1)
+        lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1)
+        upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1)
 
         # Intersect them with each other and zero, vectorized across all i
-        weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
+        weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper))
 
         # Slaney-style mel is scaled to be approx constant energy per channel
-        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm.unsqueeze(1)
+        enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels])
+        weights *= np.expand_dims(enorm, axis=1)
 
         return weights
 
-    def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
+    @staticmethod
+    def stft(
+        input_array: np.ndarray,
+        n_fft: int,
+        hop_length: int = None,
+        win_length: int = None,
+        window: np.ndarray = None,
+        center: bool = True,
+        mode: str = "reflect",
+        normalized: bool = False,
+        onesided: bool = None,
+        return_complex: bool = None,
+    ):
+        # Default initialization for hop_length and win_length
+        hop_length = hop_length if hop_length is not None else n_fft // 4
+        win_length = win_length if win_length is not None else n_fft
+        input_is_complex = np.iscomplexobj(input_array)
+
+        # Determine if the output should be complex
+        return_complex = (
+            return_complex
+            if return_complex is not None
+            else (input_is_complex or (window is not None and np.iscomplexobj(window)))
+        )
+
+        if not return_complex and return_complex is None:
+            raise ValueError(
+                "stft requires the return_complex parameter for real inputs."
+            )
+
+        # Input checks
+        if not np.issubdtype(input_array.dtype, np.floating) and not input_is_complex:
+            raise ValueError(
+                "stft: expected an array of floating point or complex values,"
+                f" got {input_array.dtype}"
+            )
+
+        if input_array.ndim > 2 or input_array.ndim < 1:
+            raise ValueError(
+                f"stft: expected a 1D or 2D array, but got {input_array.ndim}D array"
+            )
+
+        # Handle 1D input
+        if input_array.ndim == 1:
+            input_array = np.expand_dims(input_array, axis=0)
+            input_array_1d = True
+        else:
+            input_array_1d = False
+
+        # Center padding if required
+        if center:
+            pad_amount = n_fft // 2
+            input_array = np.pad(
+                input_array, ((0, 0), (pad_amount, pad_amount)), mode=mode
+            )
+
+        batch, length = input_array.shape
+
+        # Additional input checks
+        if n_fft <= 0 or n_fft > length:
+            raise ValueError(
+                f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}"
+            )
+
+        if hop_length <= 0:
+            raise ValueError(
+                f"stft: expected hop_length > 0, but got hop_length={hop_length}"
+            )
+
+        if win_length <= 0 or win_length > n_fft:
+            raise ValueError(
+                f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
+            )
+
+        if window is not None:
+            if window.ndim != 1 or window.shape[0] != win_length:
+                raise ValueError(
+                    f"stft: expected a 1D window array of size equal to win_length={win_length}, "
+                    f"but got window with size {window.shape}"
+                )
+
+        # Handle padding of the window if necessary
+        if win_length < n_fft:
+            left = (n_fft - win_length) // 2
+            window_ = np.zeros(n_fft, dtype=window.dtype)
+            window_[left : left + win_length] = window
+        else:
+            window_ = window
+
+        # Calculate the number of frames
+        n_frames = 1 + (length - n_fft) // hop_length
+
+        # Time to columns
+        input_array = np.lib.stride_tricks.as_strided(
+            input_array,
+            (batch, n_frames, n_fft),
+            (
+                input_array.strides[0],
+                hop_length * input_array.strides[1],
+                input_array.strides[1],
+            ),
+        )
+
+        if window_ is not None:
+            input_array = input_array * window_
+
+        # FFT and transpose
+        complex_fft = input_is_complex
+        onesided = onesided if onesided is not None else not complex_fft
+
+        if normalized:
+            norm = "ortho"
+        else:
+            norm = None
+
+        if complex_fft:
+            if onesided:
+                raise ValueError(
+                    "Cannot have onesided output if window or input is complex"
+                )
+            output = np.fft.fft(input_array, n=n_fft, axis=-1, norm=norm)
+        else:
+            output = np.fft.rfft(input_array, n=n_fft, axis=-1, norm=norm)
+
+        output = output.transpose((0, 2, 1))
+
+        if input_array_1d:
+            output = output.squeeze(0)
+
+        return output if return_complex else np.real(output)
+
+    def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None):
         """
         Compute the log-Mel spectrogram of the provided audio.
         """
@@ -84,31 +204,27 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
             self.n_samples = chunk_length * self.sampling_rate
             self.nb_max_frames = self.n_samples // self.hop_length
 
-        if waveform.dtype is not torch.float32:
-            waveform = waveform.to(torch.float32)
-
-        waveform = (
-            waveform.to(self.device)
-            if self.device == "cuda" and not waveform.is_cuda
-            else waveform
-        )
+        if waveform.dtype is not np.float32:
+            waveform = waveform.astype(np.float32)
 
         if padding:
-            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
+            waveform = np.pad(waveform, (0, padding))
 
-        window = torch.hann_window(self.n_fft).to(waveform.device)
+        window = np.hanning(self.n_fft + 1)[:-1].astype("float32")
 
-        stft = torch.stft(
-            waveform, self.n_fft, self.hop_length, window=window, return_complex=True
-        )
-        magnitudes = stft[..., :-1].abs() ** 2
+        stft = self.stft(
+            waveform,
+            self.n_fft,
+            self.hop_length,
+            window=window,
+            return_complex=True,
+        ).astype("complex64")
+        magnitudes = np.abs(stft[..., :-1]) ** 2
 
-        mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
+        mel_spec = self.mel_filters @ magnitudes
 
-        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
+        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
 
-        # When the model is running on multiple GPUs, the output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
-        return log_spec.cpu() if to_cpu else log_spec
+        return log_spec
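One subtlety worth recording: np.hanning(N + 1)[:-1] is the periodic Hann window, which matches what torch.hann_window(N) (periodic by default) produced before this commit. A hypothetical end-to-end check of the new numpy pipeline, assuming the defaults shown in the diff (feature_size=80, n_fft=400):

import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

# Periodic Hann window, as built in __call__ above.
n_fft = 400
window = np.hanning(n_fft + 1)[:-1]
assert window.shape == (n_fft,) and window[0] == 0.0

# One second of silence in, log-Mel features out (80 Mel bins).
extractor = FeatureExtractor()
features = extractor(np.zeros(16000, dtype=np.float32))
assert features.shape[0] == 80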
