
Commit 16ba377

Merge pull request #3086 from coqui-ai/xtts_trainer
XTTS v1.1 GPT Trainer
2 parents 1e15269 + 01839af

File tree

14 files changed: +14009 -291 lines


.github/workflows/xtts_tests.yml

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+name: xtts-tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  check_skip:
+    runs-on: ubuntu-latest
+    if: "! contains(github.event.head_commit.message, '[ci skip]')"
+    steps:
+      - run: echo "${{ github.event.head_commit.message }}"
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.9, "3.10", "3.11"]
+        experimental: [false]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: set ENV
+        run: export TRAINER_TELEMETRY=0
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends git make gcc
+          sudo apt-get install espeak
+          sudo apt-get install espeak-ng
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel
+      - name: Replace scarf urls
+        run: |
+          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
+      - name: Install TTS
+        run: |
+          python3 -m pip install .[all]
+          python3 setup.py egg_info
+      - name: Unit tests
+        run: make test_xtts

Makefile

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ test_tts: ## run tts tests.
 test_tts2: ## run tts tests.
 	nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2
 
+test_xtts:
+	nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests
+
 test_aux: ## run aux tests.
 	nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
 	./run_bash_tests.sh
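
Note: test_xtts is the same target the new workflow invokes in its final step, so the CI run can be reproduced locally with make test_xtts.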

TTS/tts/layers/xtts/gpt.py

Lines changed: 40 additions & 29 deletions
@@ -197,6 +197,7 @@ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
 
         if use_deepspeed:
             import deepspeed
+
             self.ds_engine = deepspeed.init_inference(
                 model=self.gpt_inference.half(),  # Transformers models
                 mp_size=1,  # Number of GPU
@@ -233,6 +234,7 @@ def get_logits(
         prompt=None,
         get_attns=False,
         return_latent=False,
+        attn_mask_cond=None,
         attn_mask_text=None,
         attn_mask_mel=None,
     ):
@@ -248,8 +250,11 @@
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
+                if attn_mask_cond is not None:
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+                else:
+                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
 
         gpt_out = self.gpt(
             inputs_embeds=emb,
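
Note: to make the new attn_mask_cond path concrete, here is a minimal standalone sketch (toy shapes, plain torch, not the repo's API) of the mask assembly get_logits performs, with the conditioning mask prepended ahead of the text and mel masks:

import torch

# Toy sizes: batch 2, 8 conditioning frames, 5 text tokens, 12 mel tokens.
b, cond_len, text_len, mel_len = 2, 8, 5, 12
attn_mask_cond = torch.ones(b, cond_len, dtype=torch.bool)
attn_mask_text = torch.ones(b, text_len, dtype=torch.bool)
attn_mask_mel = torch.ones(b, mel_len, dtype=torch.bool)

# Same concatenation order as get_logits: [cond | text | mel].
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
print(attn_mask.shape)  # torch.Size([2, 25])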
@@ -326,7 +331,7 @@ def get_prompts(self, prompt_codes):
         prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
         return prompt
 
-    def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
+    def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
@@ -335,26 +340,7 @@ def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            if sample:
-                _len_secs = random.randint(2, 6)  # in secs
-                cond_seg_len = int((22050 / 1024) * _len_secs)  # in frames
-                if cond_input.shape[-1] >= cond_seg_len:
-                    new_conds = []
-                    for i in range(cond_input.shape[0]):
-                        cond_len = int(cond_lens[i] / 1024)
-                        if cond_len < cond_seg_len:
-                            start = 0
-                        else:
-                            start = random.randint(0, cond_len - cond_seg_len)
-                        cond_vec = cond_input[i, :, start : start + cond_seg_len]
-                        new_conds.append(cond_vec)
-                    conds = torch.stack(new_conds, dim=0)
-            else:
-                cond_seg_len = 5 if cond_seg_len is None else cond_seg_len  # secs
-                cond_frame_len = int((22050 / 1024) * cond_seg_len)
-                conds = cond_input[:, :, -cond_frame_len:]
-
-            conds = self.conditioning_encoder(conds)
+            conds = self.conditioning_encoder(cond_input)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
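
Note: with the sampling logic deleted, get_style_emb now encodes whatever conditioning mel it receives; picking the random 2-6 s window is left to the caller. A hedged sketch of equivalent caller-side sampling (hypothetical helper, assuming the 22050 Hz / 1024-sample-hop constants from the removed code):

import random
import torch

def sample_cond_segment(mel: torch.Tensor) -> torch.Tensor:
    # Crop a random 2-6 s window from a (b, 80, s) mel, as the old code did.
    seg_len = int((22050 / 1024) * random.randint(2, 6))  # in frames
    if mel.shape[-1] <= seg_len:
        return mel
    start = random.randint(0, mel.shape[-1] - seg_len)
    return mel[:, :, start : start + seg_len]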
@@ -366,22 +352,22 @@ def forward(
         text_lengths,
         audio_codes,
         wav_lengths,
-        cond_lens=None,
         cond_mels=None,
+        cond_idxs=None,
         cond_latents=None,
-        loss_weights=None,
         return_attentions=False,
         return_latent=False,
     ):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
 
-        cond_mels: MEL float tensor, (b, 1, 80,s)
         text_inputs: long tensor, (b,t)
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
+        cond_mels: MEL float tensor, (b, 1, 80,s)
+        cond_idxs: cond start and end indexs, (b, 2)
 
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ def forward(
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
 
+        if cond_idxs is not None:
+            # recompute cond idxs for mel lengths
+            for idx, l in enumerate(code_lengths):
+                cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
+
         # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()

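Note: the new block above divides sample-level conditioning indices by code_stride_len so they address code frames instead of audio samples. A toy illustration, using integer division for clarity and a stride of 1024 as an assumption:

import torch

code_stride_len = 1024
cond_idxs = torch.tensor([[0, 132096], [22050, 154146]])  # (b, 2), in audio samples
cond_idxs = cond_idxs // code_stride_len
print(cond_idxs)  # tensor([[  0, 129], [ 21, 150]])
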
@@ -435,9 +426,16 @@
         )
 
         # Set attn_mask
+        attn_mask_cond = None
         attn_mask_text = None
         attn_mask_mel = None
         if not return_latent:
+            attn_mask_cond = torch.ones(
+                cond_mels.shape[0],
+                cond_mels.shape[-1],
+                dtype=torch.bool,
+                device=text_inputs.device,
+            )
             attn_mask_text = torch.ones(
                 text_inputs.shape[0],
                 text_inputs.shape[1],
@@ -451,6 +449,11 @@
                 device=audio_codes.device,
             )
 
+            if cond_idxs is not None:
+                for idx, r in enumerate(cond_idxs):
+                    l = r[1] - r[0]
+                    attn_mask_cond[idx, l:] = 0.0
+
             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0

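Note: a self-contained rendering of the padding logic just added: the conditioning mask starts as all ones, then positions past each item's actual segment length (end - start) are switched off:

import torch

attn_mask_cond = torch.ones(2, 10, dtype=torch.bool)
cond_idxs = torch.tensor([[0, 6], [2, 9]])  # (b, 2) start/end in code frames
for idx, r in enumerate(cond_idxs):
    l = r[1] - r[0]
    attn_mask_cond[idx, l:] = False
print(attn_mask_cond.int())
# tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.int32)
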
@@ -465,7 +468,7 @@
 
         # Compute speech conditioning input
         if cond_latents is None:
-            cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
+            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
 
         # Get logits
         sub = -5  # don't ask me why 😄
@@ -480,6 +483,7 @@
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
+            attn_mask_cond=attn_mask_cond,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
@@ -501,6 +505,13 @@
                 0
             ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
 
+        # ignore the loss for the segment used for conditioning
+        # coin flip for the segment to be ignored
+        if cond_idxs is not None:
+            cond_start = cond_idxs[idx, 0]
+            cond_end = cond_idxs[idx, 1]
+            mel_targets[idx, cond_start:cond_end] = -1
+
         # Compute losses
         loss_text = F.cross_entropy(
             text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
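
Note: setting the conditioning span of mel_targets to -1 silences it in the loss because the cross-entropy calls pass ignore_index=-1 (visible here for the text loss; the mel loss presumably matches). A minimal demo with toy shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 8, 10)          # (b, n_classes, t)
targets = torch.randint(0, 8, (1, 10))  # (b, t)
targets[0, 2:5] = -1                    # pretend frames 2..4 were the cond segment
loss = F.cross_entropy(logits, targets, ignore_index=-1)  # frames 2..4 excluded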
@@ -548,7 +559,7 @@
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+            max_length=self.max_mel_tokens,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
@@ -561,7 +572,7 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+            max_length=self.max_mel_tokens,
             do_stream=True,
             **hf_generate_kwargs,
         )

TTS/tts/layers/xtts/hifigan_decoder.py

Lines changed: 9 additions & 20 deletions
@@ -1,13 +1,12 @@
 import torch
+import torchaudio
 from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm
-import torchaudio
 
 from TTS.utils.io import load_fsspec
 
-
 LRELU_SLOPE = 0.1
 
 
@@ -224,9 +223,7 @@ def __init__(
         self.cond_in_each_up_layer = cond_in_each_up_layer
 
         # initial upsampling layers
-        self.conv_pre = weight_norm(
-            Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
-        )
+        self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if resblock_type == "1" else ResBlock2
         # upsampling layers
         self.ups = nn.ModuleList()
@@ -246,14 +243,10 @@ def __init__(
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(
-            Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
-        )
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
         if cond_channels > 0:
             self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
@@ -318,9 +311,7 @@ def inference(self, c):
             Tensor: [B, 1, T]
         """
         c = c.to(self.conv_pre.weight.device)
-        c = torch.nn.functional.pad(
-            c, (self.inference_padding, self.inference_padding), "replicate"
-        )
+        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
         return self.forward(c)
 
     def remove_weight_norm(self):
@@ -342,6 +333,7 @@ def load_checkpoint(
         assert not self.training
         self.remove_weight_norm()
 
+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -425,10 +417,8 @@ def forward(self, x):
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
 
 
-
 class ResNetSpeakerEncoder(nn.Module):
-    """This is copied from 🐸TTS to remove it from the dependencies.
-    """
+    """This is copied from 🐸TTS to remove it from the dependencies."""
 
     # pylint: disable=W0102
     def __init__(
@@ -620,6 +610,7 @@ def load_checkpoint(
             return criterion, state["step"]
         return criterion
 
+
 class HifiDecoder(torch.nn.Module):
     def __init__(
         self,
@@ -724,9 +715,7 @@ def inference(self, c, g):
         """
         return self.forward(c, g=g)
 
-    def load_checkpoint(
-        self, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # remove unused keys
         state = state["model"]
