1
- import torch
1
+ import numpy as np
2
2
3
3
4
- # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
5
4
class FeatureExtractor :
6
5
def __init__ (
7
6
self ,
8
- device : str = "auto" ,
9
7
feature_size = 80 ,
10
8
sampling_rate = 16000 ,
11
9
hop_length = 160 ,
12
10
chunk_length = 30 ,
13
11
n_fft = 400 ,
14
12
):
15
- if device == "auto" :
16
- self .device = "cuda" if torch .cuda .is_available () else "cpu"
17
- else :
18
- self .device = device
19
13
self .n_fft = n_fft
20
14
self .hop_length = hop_length
21
15
self .chunk_length = chunk_length
@@ -25,24 +19,21 @@ def __init__(
25
19
self .sampling_rate = sampling_rate
26
20
self .mel_filters = self .get_mel_filters (
27
21
sampling_rate , n_fft , n_mels = feature_size
28
- )
22
+ ). astype ( "float32" )
29
23
30
24
@staticmethod
31
25
def get_mel_filters (sr , n_fft , n_mels = 128 ):
32
- """
33
- Implementation of librosa.filters.mel in Pytorch
34
- """
35
26
# Initialize the weights
36
27
n_mels = int (n_mels )
37
28
38
29
# Center freqs of each FFT bin
39
- fftfreqs = torch .fft .rfftfreq (n = n_fft , d = 1.0 / sr )
30
+ fftfreqs = np .fft .rfftfreq (n = n_fft , d = 1.0 / sr )
40
31
41
32
# 'Center freqs' of mel bands - uniformly spaced between limits
42
33
min_mel = 0.0
43
34
max_mel = 45.245640471924965
44
35
45
- mels = torch .linspace (min_mel , max_mel , n_mels + 2 )
36
+ mels = np .linspace (min_mel , max_mel , n_mels + 2 )
46
37
47
38
# Fill in the linear scale
48
39
f_min = 0.0
@@ -52,30 +43,159 @@ def get_mel_filters(sr, n_fft, n_mels=128):
52
43
# And now the nonlinear scale
53
44
min_log_hz = 1000.0 # beginning of log region (Hz)
54
45
min_log_mel = (min_log_hz - f_min ) / f_sp # same (Mels)
55
- logstep = torch .log (torch . tensor ( 6.4 ) ) / 27.0 # step size for log region
46
+ logstep = np .log (6.4 ) / 27.0 # step size for log region
56
47
57
48
# If we have vector data, vectorize
58
49
log_t = mels >= min_log_mel
59
- freqs [log_t ] = min_log_hz * torch .exp (logstep * (mels [log_t ] - min_log_mel ))
60
-
61
- mel_f = freqs
50
+ freqs [log_t ] = min_log_hz * np .exp (logstep * (mels [log_t ] - min_log_mel ))
62
51
63
- fdiff = torch .diff (mel_f )
64
- ramps = mel_f . view (- 1 , 1 ) - fftfreqs .view (1 , - 1 )
52
+ fdiff = np .diff (freqs )
53
+ ramps = freqs . reshape (- 1 , 1 ) - fftfreqs .reshape (1 , - 1 )
65
54
66
- lower = - ramps [:- 2 ] / fdiff [:- 1 ]. unsqueeze ( 1 )
67
- upper = ramps [2 :] / fdiff [1 :]. unsqueeze ( 1 )
55
+ lower = - ramps [:- 2 ] / np . expand_dims ( fdiff [:- 1 ], axis = 1 )
56
+ upper = ramps [2 :] / np . expand_dims ( fdiff [1 :], axis = 1 )
68
57
69
58
# Intersect them with each other and zero, vectorized across all i
70
- weights = torch .maximum (torch .zeros_like (lower ), torch .minimum (lower , upper ))
59
+ weights = np .maximum (np .zeros_like (lower ), np .minimum (lower , upper ))
71
60
72
61
# Slaney-style mel is scaled to be approx constant energy per channel
73
- enorm = 2.0 / (mel_f [2 : n_mels + 2 ] - mel_f [:n_mels ])
74
- weights *= enorm . unsqueeze ( 1 )
62
+ enorm = 2.0 / (freqs [2 : n_mels + 2 ] - freqs [:n_mels ])
63
+ weights *= np . expand_dims ( enorm , axis = 1 )
75
64
76
65
return weights
77
66
78
- def __call__ (self , waveform , padding = True , chunk_length = None , to_cpu = False ):
67
+ @staticmethod
68
+ def stft (
69
+ input_array : np .ndarray ,
70
+ n_fft : int ,
71
+ hop_length : int = None ,
72
+ win_length : int = None ,
73
+ window : np .ndarray = None ,
74
+ center : bool = True ,
75
+ mode : str = "reflect" ,
76
+ normalized : bool = False ,
77
+ onesided : bool = None ,
78
+ return_complex : bool = None ,
79
+ ):
80
+ # Default initialization for hop_length and win_length
81
+ hop_length = hop_length if hop_length is not None else n_fft // 4
82
+ win_length = win_length if win_length is not None else n_fft
83
+ input_is_complex = np .iscomplexobj (input_array )
84
+
85
+ # Determine if the output should be complex
86
+ return_complex = (
87
+ return_complex
88
+ if return_complex is not None
89
+ else (input_is_complex or (window is not None and np .iscomplexobj (window )))
90
+ )
91
+
92
+ if not return_complex and return_complex is None :
93
+ raise ValueError (
94
+ "stft requires the return_complex parameter for real inputs."
95
+ )
96
+
97
+ # Input checks
98
+ if not np .issubdtype (input_array .dtype , np .floating ) and not input_is_complex :
99
+ raise ValueError (
100
+ "stft: expected an array of floating point or complex values,"
101
+ f" got { input_array .dtype } "
102
+ )
103
+
104
+ if input_array .ndim > 2 or input_array .ndim < 1 :
105
+ raise ValueError (
106
+ f"stft: expected a 1D or 2D array, but got { input_array .ndim } D array"
107
+ )
108
+
109
+ # Handle 1D input
110
+ if input_array .ndim == 1 :
111
+ input_array = np .expand_dims (input_array , axis = 0 )
112
+ input_array_1d = True
113
+ else :
114
+ input_array_1d = False
115
+
116
+ # Center padding if required
117
+ if center :
118
+ pad_amount = n_fft // 2
119
+ input_array = np .pad (
120
+ input_array , ((0 , 0 ), (pad_amount , pad_amount )), mode = mode
121
+ )
122
+
123
+ batch , length = input_array .shape
124
+
125
+ # Additional input checks
126
+ if n_fft <= 0 or n_fft > length :
127
+ raise ValueError (
128
+ f"stft: expected 0 < n_fft <= { length } , but got n_fft={ n_fft } "
129
+ )
130
+
131
+ if hop_length <= 0 :
132
+ raise ValueError (
133
+ f"stft: expected hop_length > 0, but got hop_length={ hop_length } "
134
+ )
135
+
136
+ if win_length <= 0 or win_length > n_fft :
137
+ raise ValueError (
138
+ f"stft: expected 0 < win_length <= n_fft, but got win_length={ win_length } "
139
+ )
140
+
141
+ if window is not None :
142
+ if window .ndim != 1 or window .shape [0 ] != win_length :
143
+ raise ValueError (
144
+ f"stft: expected a 1D window array of size equal to win_length={ win_length } , "
145
+ f"but got window with size { window .shape } "
146
+ )
147
+
148
+ # Handle padding of the window if necessary
149
+ if win_length < n_fft :
150
+ left = (n_fft - win_length ) // 2
151
+ window_ = np .zeros (n_fft , dtype = window .dtype )
152
+ window_ [left : left + win_length ] = window
153
+ else :
154
+ window_ = window
155
+
156
+ # Calculate the number of frames
157
+ n_frames = 1 + (length - n_fft ) // hop_length
158
+
159
+ # Time to columns
160
+ input_array = np .lib .stride_tricks .as_strided (
161
+ input_array ,
162
+ (batch , n_frames , n_fft ),
163
+ (
164
+ input_array .strides [0 ],
165
+ hop_length * input_array .strides [1 ],
166
+ input_array .strides [1 ],
167
+ ),
168
+ )
169
+
170
+ if window_ is not None :
171
+ input_array = input_array * window_
172
+
173
+ # FFT and transpose
174
+ complex_fft = input_is_complex
175
+ onesided = onesided if onesided is not None else not complex_fft
176
+
177
+ if normalized :
178
+ norm = "ortho"
179
+ else :
180
+ norm = None
181
+
182
+ if complex_fft :
183
+ if onesided :
184
+ raise ValueError (
185
+ "Cannot have onesided output if window or input is complex"
186
+ )
187
+ output = np .fft .fft (input_array , n = n_fft , axis = - 1 , norm = norm )
188
+ else :
189
+ output = np .fft .rfft (input_array , n = n_fft , axis = - 1 , norm = norm )
190
+
191
+ output = output .transpose ((0 , 2 , 1 ))
192
+
193
+ if input_array_1d :
194
+ output = output .squeeze (0 )
195
+
196
+ return output if return_complex else np .real (output )
197
+
198
+ def __call__ (self , waveform : np .ndarray , padding = 160 , chunk_length = None ):
79
199
"""
80
200
Compute the log-Mel spectrogram of the provided audio.
81
201
"""
@@ -84,31 +204,27 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
84
204
self .n_samples = chunk_length * self .sampling_rate
85
205
self .nb_max_frames = self .n_samples // self .hop_length
86
206
87
- if waveform .dtype is not torch .float32 :
88
- waveform = waveform .to (torch .float32 )
89
-
90
- waveform = (
91
- waveform .to (self .device )
92
- if self .device == "cuda" and not waveform .is_cuda
93
- else waveform
94
- )
207
+ if waveform .dtype is not np .float32 :
208
+ waveform = waveform .astype (np .float32 )
95
209
96
210
if padding :
97
- waveform = torch . nn . functional . pad (waveform , (0 , self . n_samples ))
211
+ waveform = np . pad (waveform , (0 , padding ))
98
212
99
- window = torch . hann_window (self .n_fft ). to ( waveform . device )
213
+ window = np . hanning (self .n_fft + 1 )[: - 1 ]. astype ( "float32" )
100
214
101
- stft = torch .stft (
102
- waveform , self .n_fft , self .hop_length , window = window , return_complex = True
103
- )
104
- magnitudes = stft [..., :- 1 ].abs () ** 2
215
+ stft = self .stft (
216
+ waveform ,
217
+ self .n_fft ,
218
+ self .hop_length ,
219
+ window = window ,
220
+ return_complex = True ,
221
+ ).astype ("complex64" )
222
+ magnitudes = np .abs (stft [..., :- 1 ]) ** 2
105
223
106
- mel_spec = self .mel_filters . to ( waveform . device ) @ magnitudes
224
+ mel_spec = self .mel_filters @ magnitudes
107
225
108
- log_spec = torch . clamp ( mel_spec , min = 1e-10 ). log10 ( )
109
- log_spec = torch .maximum (log_spec , log_spec .max () - 8.0 )
226
+ log_spec = np . log10 ( np . clip ( mel_spec , a_min = 1e-10 , a_max = None ) )
227
+ log_spec = np .maximum (log_spec , log_spec .max () - 8.0 )
110
228
log_spec = (log_spec + 4.0 ) / 4.0
111
229
112
- # When the model is running on multiple GPUs, the output should be moved
113
- # to the CPU since we don't know which GPU will handle the next job.
114
- return log_spec .cpu () if to_cpu else log_spec
230
+ return log_spec
0 commit comments