Skip to content

Commit 712d527

Browse files
committed
update Bigvgan vocoder and F5-bigvgan version, trained on Emilia ZH&EN, 1.25m updates
1 parent dee0420 commit 712d527

File tree

14 files changed

+368
-180
lines changed

14 files changed

+368
-180
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "src/third_party/BigVGAN"]
2+
path = src/third_party/BigVGAN
3+
url = https://github.com/NVIDIA/BigVGAN.git

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,18 @@ cd F5-TTS
4646
pip install -e .
4747
```
4848

49-
### 3. Docker usage
49+
### 3. Init submodule (optional; only needed if you want to switch the vocoder from vocos to bigvgan)
50+
51+
```bash
52+
git submodule update --init --recursive
53+
```
54+
After that, edit `src/third_party/BigVGAN/bigvgan.py` and add the following code at the beginning of the file, so that the directory containing `bigvgan.py` is on Python's module search path:
55+
```python
56+
import os
import sys
57+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
58+
```
59+
60+
### 4. Docker usage
5061
```bash
5162
# Build from Dockerfile
5263
docker build -t f5tts:v1 .

src/f5_tts/api.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,18 @@
11
import random
22
import sys
3-
import tqdm
43
from importlib.resources import files
54

65
import soundfile as sf
76
import torch
7+
import tqdm
88
from cached_path import cached_path
99

10+
from f5_tts.infer.utils_infer import (hop_length, infer_process, load_model,
11+
load_vocoder, preprocess_ref_audio_text,
12+
remove_silence_for_generated_wav,
13+
save_spectrogram, target_sample_rate)
1014
from f5_tts.model import DiT, UNetT
1115
from f5_tts.model.utils import seed_everything
12-
from f5_tts.infer.utils_infer import (
13-
load_vocoder,
14-
load_model,
15-
infer_process,
16-
remove_silence_for_generated_wav,
17-
save_spectrogram,
18-
preprocess_ref_audio_text,
19-
target_sample_rate,
20-
hop_length,
21-
)
2216

2317

2418
class F5TTS:
@@ -29,6 +23,7 @@ def __init__(
2923
vocab_file="",
3024
ode_method="euler",
3125
use_ema=True,
26+
vocoder_name="vocos",
3227
local_path=None,
3328
device=None,
3429
):
@@ -44,11 +39,11 @@ def __init__(
4439
)
4540

4641
# Load models
47-
self.load_vocoder_model(local_path)
42+
self.load_vocoder_model(vocoder_name, local_path)
4843
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
4944

50-
def load_vocoder_model(self, local_path):
51-
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
45+
def load_vocoder_model(self, vocoder_name, local_path):
46+
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
5247

5348
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
5449
if model_type == "F5-TTS":

src/f5_tts/eval/eval_infer_batch.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,23 @@
1-
import sys
21
import os
2+
import sys
33

44
sys.path.append(os.getcwd())
55

6-
import time
7-
from tqdm import tqdm
86
import argparse
7+
import time
98
from importlib.resources import files
109

1110
import torch
1211
import torchaudio
1312
from accelerate import Accelerator
14-
from vocos import Vocos
13+
from tqdm import tqdm
1514

16-
from f5_tts.model import CFM, UNetT, DiT
15+
from f5_tts.eval.utils_eval import (get_inference_prompt,
16+
get_librispeech_test_clean_metainfo,
17+
get_seedtts_testset_metainfo)
18+
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
19+
from f5_tts.model import CFM, DiT, UNetT
1720
from f5_tts.model.utils import get_tokenizer
18-
from f5_tts.infer.utils_infer import load_checkpoint
19-
from f5_tts.eval.utils_eval import (
20-
get_seedtts_testset_metainfo,
21-
get_librispeech_test_clean_metainfo,
22-
get_inference_prompt,
23-
)
2421

2522
accelerator = Accelerator()
2623
device = f"cuda:{accelerator.process_index}"
@@ -31,8 +28,12 @@
3128
target_sample_rate = 24000
3229
n_mel_channels = 100
3330
hop_length = 256
31+
win_length = 1024
32+
n_fft = 1024
33+
extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
3434
target_rms = 0.1
3535

36+
3637
tokenizer = "pinyin"
3738
rel_path = str(files("f5_tts").joinpath("../../"))
3839

@@ -123,14 +124,11 @@ def main():
123124

124125
# Vocoder model
125126
local = False
126-
if local:
127-
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
128-
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
129-
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
130-
vocos.load_state_dict(state_dict)
131-
vocos.eval()
132-
else:
133-
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
127+
if extract_backend == "vocos":
128+
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129+
elif extract_backend == "bigvgan":
130+
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
131+
vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
134132

135133
# Tokenizer
136134
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -139,17 +137,21 @@ def main():
139137
model = CFM(
140138
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
141139
mel_spec_kwargs=dict(
142-
target_sample_rate=target_sample_rate,
143-
n_mel_channels=n_mel_channels,
140+
n_fft=n_fft,
144141
hop_length=hop_length,
142+
win_length=win_length,
143+
n_mel_channels=n_mel_channels,
144+
target_sample_rate=target_sample_rate,
145+
extract_backend=extract_backend,
145146
),
146147
odeint_kwargs=dict(
147148
method=ode_method,
148149
),
149150
vocab_char_map=vocab_char_map,
150151
).to(device)
151152

152-
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
153+
dtype = torch.float16 if extract_backend == "vocos" else torch.float32
154+
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
153155

154156
if not os.path.exists(output_dir) and accelerator.is_main_process:
155157
os.makedirs(output_dir)
@@ -178,14 +180,18 @@ def main():
178180
no_ref_audio=no_ref_audio,
179181
seed=seed,
180182
)
181-
# Final result
182-
for i, gen in enumerate(generated):
183-
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
184-
gen_mel_spec = gen.permute(0, 2, 1)
185-
generated_wave = vocos.decode(gen_mel_spec.cpu())
186-
if ref_rms_list[i] < target_rms:
187-
generated_wave = generated_wave * ref_rms_list[i] / target_rms
188-
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
183+
# Final result
184+
for i, gen in enumerate(generated):
185+
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
186+
gen_mel_spec = gen.permute(0, 2, 1)
187+
if extract_backend == "vocos":
188+
generated_wave = vocoder.decode(gen_mel_spec.cpu())
189+
elif extract_backend == "bigvgan":
190+
generated_wave = vocoder(gen_mel_spec)
191+
192+
if ref_rms_list[i] < target_rms:
193+
generated_wave = generated_wave * ref_rms_list[i] / target_rms
194+
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
189195

190196
accelerator.wait_for_everyone()
191197
if accelerator.is_main_process:

src/f5_tts/eval/utils_eval.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import os
33
import random
44
import string
5-
from tqdm import tqdm
65

76
import torch
87
import torch.nn.functional as F
98
import torchaudio
9+
from tqdm import tqdm
1010

11+
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
1112
from f5_tts.model.modules import MelSpec
1213
from f5_tts.model.utils import convert_char_to_pinyin
13-
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
1414

1515

1616
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
@@ -74,8 +74,11 @@ def get_inference_prompt(
7474
tokenizer="pinyin",
7575
polyphone=True,
7676
target_sample_rate=24000,
77+
n_fft=1024,
78+
win_length=1024,
7779
n_mel_channels=100,
7880
hop_length=256,
81+
extract_backend="bigvgan",
7982
target_rms=0.1,
8083
use_truth_duration=False,
8184
infer_batch_size=1,
@@ -94,7 +97,12 @@ def get_inference_prompt(
9497
)
9598

9699
mel_spectrogram = MelSpec(
97-
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
100+
n_fft=n_fft,
101+
hop_length=hop_length,
102+
win_length=win_length,
103+
n_mel_channels=n_mel_channels,
104+
target_sample_rate=target_sample_rate,
105+
extract_backend=extract_backend,
98106
)
99107

100108
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):

src/f5_tts/infer/infer_cli.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,18 @@
22
import codecs
33
import os
44
import re
5-
from pathlib import Path
65
from importlib.resources import files
6+
from pathlib import Path
77

88
import numpy as np
99
import soundfile as sf
1010
import tomli
1111
from cached_path import cached_path
1212

13+
from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
14+
preprocess_ref_audio_text,
15+
remove_silence_for_generated_wav)
1316
from f5_tts.model import DiT, UNetT
14-
from f5_tts.infer.utils_infer import (
15-
load_vocoder,
16-
load_model,
17-
preprocess_ref_audio_text,
18-
infer_process,
19-
remove_silence_for_generated_wav,
20-
)
21-
2217

2318
parser = argparse.ArgumentParser(
2419
prog="python3 infer-cli.py",
@@ -70,6 +65,7 @@
7065
"--remove_silence",
7166
help="Remove silence.",
7267
)
68+
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
7369
parser.add_argument(
7470
"--load_vocoder_from_local",
7571
action="store_true",
@@ -111,9 +107,14 @@
111107
speed = args.speed
112108
wave_path = Path(output_dir) / "infer_cli_out.wav"
113109
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114-
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
110+
if args.vocoder_name == "vocos":
111+
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
112+
elif args.vocoder_name == "bigvgan":
113+
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
115114

116-
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
115+
vocoder = load_vocoder(
116+
vocoder_name=args.vocoder_name, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
117+
)
117118

118119

119120
# load models
@@ -136,6 +137,12 @@
136137
ckpt_step = 1200000
137138
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
138139
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
140+
elif args.vocoder_name == "bigvgan": # TODO: need to test
141+
repo_name = "F5-TTS"
142+
exp_name = "F5TTS_Base_bigvgan"
143+
ckpt_step = 1250000
144+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
145+
139146

140147
print(f"Using {model}...")
141148
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)

0 commit comments

Comments
 (0)