Skip to content

Commit 8aacb81

Browse files
authored
Fix Tortoise load (#2791)
* Remove key prunning in tortoise * Make lint
1 parent b3472a7 commit 8aacb81

File tree

3 files changed

+6
-11
lines changed

3 files changed

+6
-11
lines changed

TTS/tts/models/tortoise.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import random
3-
import re
43
from contextlib import contextmanager
54
from dataclasses import dataclass
65
from time import time
@@ -876,16 +875,12 @@ def load_checkpoint(
876875
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
877876

878877
if os.path.exists(ar_path):
879-
keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
880878
# remove keys from the checkpoint that are not in the model
881879
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
882-
for key in list(checkpoint.keys()):
883-
for pat in keys_to_ignore:
884-
if re.search(pat, key) is not None:
885-
del checkpoint[key]
886-
break
887880

888-
self.autoregressive.load_state_dict(checkpoint, strict=strict)
881+
# strict set False
882+
# due to removed `bias` and `masked_bias` changes in Transformers
883+
self.autoregressive.load_state_dict(checkpoint, strict=False)
889884

890885
if os.path.exists(diff_path):
891886
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)

TTS/tts/utils/text/japanese/phonemizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
try:
88
import MeCab
9-
except ImportError:
10-
raise ImportError("Japanese requires mecab-python3 and unidic-lite.")
9+
except ImportError as e:
10+
raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e
1111
from num2words import num2words
1212

1313
_CONVRULES = [

tests/api_tests/test_synthesize_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ def test_synthesize():
1212
'tts --model_name "coqui_studio/en/Torcull Diarmuid/coqui_studio" '
1313
'--text "This is it" '
1414
f'--out_path "{output_path}"'
15-
)
15+
)

0 commit comments

Comments
 (0)