@@ -197,6 +197,7 @@ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
197
197
198
198
if use_deepspeed :
199
199
import deepspeed
200
+
200
201
self .ds_engine = deepspeed .init_inference (
201
202
model = self .gpt_inference .half (), # Transformers models
202
203
mp_size = 1 , # Number of GPU
@@ -233,6 +234,7 @@ def get_logits(
233
234
prompt = None ,
234
235
get_attns = False ,
235
236
return_latent = False ,
237
+ attn_mask_cond = None ,
236
238
attn_mask_text = None ,
237
239
attn_mask_mel = None ,
238
240
):
@@ -248,8 +250,11 @@ def get_logits(
248
250
if attn_mask_text is not None :
249
251
attn_mask = torch .cat ([attn_mask_text , attn_mask_mel ], dim = 1 )
250
252
if prompt is not None :
251
- attn_mask_prompt = torch .ones (prompt .shape [0 ], offset , dtype = torch .bool , device = emb .device )
252
- attn_mask = torch .cat ([attn_mask_prompt , attn_mask ], dim = 1 )
253
+ if attn_mask_cond is not None :
254
+ attn_mask = torch .cat ([attn_mask_cond , attn_mask ], dim = 1 )
255
+ else :
256
+ attn_mask_cond = torch .ones (prompt .shape [0 ], offset , dtype = torch .bool , device = emb .device )
257
+ attn_mask = torch .cat ([attn_mask_cond , attn_mask ], dim = 1 )
253
258
254
259
gpt_out = self .gpt (
255
260
inputs_embeds = emb ,
@@ -326,7 +331,7 @@ def get_prompts(self, prompt_codes):
326
331
prompt = F .pad (prompt , (0 , 1 ), value = self .stop_prompt_token )
327
332
return prompt
328
333
329
- def get_style_emb (self , cond_input , cond_lens = None , cond_seg_len = None , return_latent = False , sample = True ):
334
+ def get_style_emb (self , cond_input , return_latent = False ):
330
335
"""
331
336
cond_input: (b, 80, s) or (b, 1, 80, s)
332
337
conds: (b, 1024, s)
@@ -335,26 +340,7 @@ def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_la
335
340
if not return_latent :
336
341
if cond_input .ndim == 4 :
337
342
cond_input = cond_input .squeeze (1 )
338
- if sample :
339
- _len_secs = random .randint (2 , 6 ) # in secs
340
- cond_seg_len = int ((22050 / 1024 ) * _len_secs ) # in frames
341
- if cond_input .shape [- 1 ] >= cond_seg_len :
342
- new_conds = []
343
- for i in range (cond_input .shape [0 ]):
344
- cond_len = int (cond_lens [i ] / 1024 )
345
- if cond_len < cond_seg_len :
346
- start = 0
347
- else :
348
- start = random .randint (0 , cond_len - cond_seg_len )
349
- cond_vec = cond_input [i , :, start : start + cond_seg_len ]
350
- new_conds .append (cond_vec )
351
- conds = torch .stack (new_conds , dim = 0 )
352
- else :
353
- cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
354
- cond_frame_len = int ((22050 / 1024 ) * cond_seg_len )
355
- conds = cond_input [:, :, - cond_frame_len :]
356
-
357
- conds = self .conditioning_encoder (conds )
343
+ conds = self .conditioning_encoder (cond_input )
358
344
else :
359
345
# already computed
360
346
conds = cond_input .unsqueeze (1 )
@@ -366,22 +352,22 @@ def forward(
366
352
text_lengths ,
367
353
audio_codes ,
368
354
wav_lengths ,
369
- cond_lens = None ,
370
355
cond_mels = None ,
356
+ cond_idxs = None ,
371
357
cond_latents = None ,
372
- loss_weights = None ,
373
358
return_attentions = False ,
374
359
return_latent = False ,
375
360
):
376
361
"""
377
362
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
378
363
(actuated by `text_first`).
379
364
380
- cond_mels: MEL float tensor, (b, 1, 80,s)
381
365
text_inputs: long tensor, (b,t)
382
366
text_lengths: long tensor, (b,)
383
367
mel_inputs: long tensor, (b,m)
384
368
wav_lengths: long tensor, (b,)
369
+ cond_mels: MEL float tensor, (b, 1, 80,s)
370
+ cond_idxs: cond start and end indices, (b, 2)
385
371
386
372
If return_attentions is specified, only logits are returned.
387
373
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ def forward(
393
379
max_text_len = text_lengths .max ()
394
380
code_lengths = torch .ceil (wav_lengths / self .code_stride_len ).long () + 3
395
381
382
+ if cond_idxs is not None :
383
+ # recompute cond idxs for mel lengths
384
+ for idx , l in enumerate (code_lengths ):
385
+ cond_idxs [idx ] = cond_idxs [idx ] / self .code_stride_len
386
+
396
387
# If len(codes) + 3 is larger than maximum allowed length, we truncate the codes.
397
388
max_mel_len = code_lengths .max ()
398
389
@@ -435,9 +426,16 @@ def forward(
435
426
)
436
427
437
428
# Set attn_mask
429
+ attn_mask_cond = None
438
430
attn_mask_text = None
439
431
attn_mask_mel = None
440
432
if not return_latent :
433
+ attn_mask_cond = torch .ones (
434
+ cond_mels .shape [0 ],
435
+ cond_mels .shape [- 1 ],
436
+ dtype = torch .bool ,
437
+ device = text_inputs .device ,
438
+ )
441
439
attn_mask_text = torch .ones (
442
440
text_inputs .shape [0 ],
443
441
text_inputs .shape [1 ],
@@ -451,6 +449,11 @@ def forward(
451
449
device = audio_codes .device ,
452
450
)
453
451
452
+ if cond_idxs is not None :
453
+ for idx , r in enumerate (cond_idxs ):
454
+ l = r [1 ] - r [0 ]
455
+ attn_mask_cond [idx , l :] = 0.0
456
+
454
457
for idx , l in enumerate (text_lengths ):
455
458
attn_mask_text [idx , l + 1 :] = 0.0
456
459
@@ -465,7 +468,7 @@ def forward(
465
468
466
469
# Compute speech conditioning input
467
470
if cond_latents is None :
468
- cond_latents = self .get_style_emb (cond_mels , cond_lens ).transpose (1 , 2 )
471
+ cond_latents = self .get_style_emb (cond_mels ).transpose (1 , 2 )
469
472
470
473
# Get logits
471
474
sub = - 5 # don't ask me why 😄
@@ -480,6 +483,7 @@ def forward(
480
483
prompt = cond_latents ,
481
484
get_attns = return_attentions ,
482
485
return_latent = return_latent ,
486
+ attn_mask_cond = attn_mask_cond ,
483
487
attn_mask_text = attn_mask_text ,
484
488
attn_mask_mel = attn_mask_mel ,
485
489
)
@@ -501,6 +505,13 @@ def forward(
501
505
0
502
506
], f" ❗ mel_targets does not contain stop token ({ self .stop_audio_token } ) in every row."
503
507
508
+ # ignore the loss for the segment used for conditioning
509
+ # coin flip for the segment to be ignored
510
+ if cond_idxs is not None :
511
+ cond_start = cond_idxs [idx , 0 ]
512
+ cond_end = cond_idxs [idx , 1 ]
513
+ mel_targets [idx , cond_start :cond_end ] = - 1
514
+
504
515
# Compute losses
505
516
loss_text = F .cross_entropy (
506
517
text_logits , text_targets .long (), ignore_index = - 1 , label_smoothing = self .label_smoothing
@@ -548,7 +559,7 @@ def generate(
548
559
bos_token_id = self .start_audio_token ,
549
560
pad_token_id = self .stop_audio_token ,
550
561
eos_token_id = self .stop_audio_token ,
551
- max_length = self .max_mel_tokens * 2 + self . max_prompt_tokens + self . max_text_tokens ,
562
+ max_length = self .max_mel_tokens ,
552
563
** hf_generate_kwargs ,
553
564
)
554
565
if "return_dict_in_generate" in hf_generate_kwargs :
@@ -561,7 +572,7 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
561
572
bos_token_id = self .start_audio_token ,
562
573
pad_token_id = self .stop_audio_token ,
563
574
eos_token_id = self .stop_audio_token ,
564
- max_length = self .max_mel_tokens * 2 + self . max_prompt_tokens + self . max_text_tokens ,
575
+ max_length = self .max_mel_tokens ,
565
576
do_stream = True ,
566
577
** hf_generate_kwargs ,
567
578
)
0 commit comments