
Commit c844b65

Inference API for 🐶Bark (coqui-ai#2685)
* Add bark requirements
* Draft Bark implementation
* Download HF models
* Update synthesizer
* Add bark model
* Make style
* Update pylintrc
* Update model URLs
* Update Bark Config
* Fix here and there
* Make style
* Make lint
* Update requirements
* Update requirements
1 parent 4cf8652 commit c844b65

File tree: 18 files changed (+1757 −101 lines)

.pylintrc

Lines changed: 3 additions & 1 deletion
@@ -169,7 +169,9 @@ disable=missing-docstring,
         comprehension-escape,
         duplicate-code,
         not-callable,
-        import-outside-toplevel
+        import-outside-toplevel,
+        logging-fstring-interpolation,
+        logging-not-lazy
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
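
For context, a brief illustrative snippet (not taken from the diff) showing the kind of logging calls these two newly disabled checks would otherwise flag:

import logging

logger = logging.getLogger(__name__)
model_name = "bark"

logger.info(f"Loading {model_name} model")     # flagged by logging-fstring-interpolation
logger.info("Loading %s model" % model_name)   # flagged by logging-not-lazy (eager % formatting)
logger.info("Loading %s model", model_name)    # lazy form that both checks steer toward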

TTS/.models.json

Lines changed: 86 additions & 73 deletions
Large diffs are not rendered by default.

TTS/api.py

Lines changed: 3 additions & 1 deletion
@@ -346,7 +346,7 @@ def list_models():
 
     def download_model_by_name(self, model_name: str):
         model_path, config_path, model_item = self.manager.download_model(model_name)
-        if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
+        if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
             # return model directory if there are multiple files
             # we assume that the model knows how to load itself
             return None, None, None, None, model_path
@@ -584,6 +584,8 @@ def tts_to_file(
                 Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
             file_path (str, optional):
                 Output file path. Defaults to "output.wav".
+            kwargs (dict, optional):
+                Additional arguments for the model.
         """
         self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

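For illustration, a hedged sketch of using the extended signature from the high-level API. The Bark model name and the voice_dir/speaker keywords are assumptions (the registered name lives in the unrendered .models.json diff, and the accepted kwargs are defined by the model itself), not values confirmed by this hunk:

from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/bark", progress_bar=False)  # assumed model name

# Extra keyword arguments are now documented and forwarded to the model.
tts.tts_to_file(
    text="Hello, this is Bark speaking through the Coqui TTS API.",
    file_path="bark_output.wav",
    voice_dir="bark_voices/",  # assumed Bark-specific kwarg for voice-cloning prompts
    speaker="ljspeech",        # assumed speaker/prompt name
)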
TTS/bin/synthesize.py

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ def main():
         vc_config_path = config_path
 
     # tts model with multiple files to be loaded from the directory path
-    if model_item.get("author", None) == "fairseq" or isinstance(model_item["github_rls_url"], list):
+    if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
         model_dir = model_path
         tts_path = None
         tts_config_path = None
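
The same github_rls_url to model_url rename applies here: a list-valued model_url marks a multi-file model that must be loaded from its download directory. A self-contained sketch of the branch with a hypothetical registry entry (all field values are illustrative, not taken from .models.json):

model_item = {
    "author": "suno",
    "model_url": [  # multiple checkpoint files -> load from a directory
        "https://huggingface.co/erogol/bark/tree/main/text_2.pt",
        "https://huggingface.co/erogol/bark/tree/main/coarse_2.pt",
        "https://huggingface.co/erogol/bark/tree/main/fine_2.pt",
    ],
}
model_path = "/home/user/.local/share/tts/tts_models--multilingual--multi-dataset--bark"  # illustrative

if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
    model_dir, tts_path, tts_config_path = model_path, None, None  # point the synthesizer at the directory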

TTS/tts/configs/bark_config.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import os
from dataclasses import dataclass
from typing import Dict

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPTConfig
from TTS.tts.models.bark import BarkAudioConfig
from TTS.utils.generic_utils import get_user_data_dir


@dataclass
class BarkConfig(BaseTTSConfig):
    """Bark TTS configuration

    Args:
        model (str): model name that registers the model.
        audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig().
        num_chars (int): number of characters in the alphabet. Defaults to 0.
        semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig().
        fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig().
        coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig().
        CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024.
        SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9.
        SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000.
        CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024.
        N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2.
        N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8.
        COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75.
        SAMPLE_RATE (int): sample rate. Defaults to 24_000.
        USE_SMALLER_MODELS (bool): use smaller models. Defaults to False.
        TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048.
        SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000.
        TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048.
        TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049.
        TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050.
        SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051.
        COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048.
        COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050.
        REMOTE_BASE_URL ([type]): remote base URL. Defaults to "https://huggingface.co/erogol/bark/tree".
        REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None.
        LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None.
        SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None.
        CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir().
        DEF_SPEAKER_DIR (str): default speaker directory to store speaker values for voice cloning. Defaults to get_user_data_dir().
    """

    model: str = "bark"
    audio: BarkAudioConfig = BarkAudioConfig()
    num_chars: int = 0
    semantic_config: GPTConfig = GPTConfig()
    fine_config: FineGPTConfig = FineGPTConfig()
    coarse_config: GPTConfig = GPTConfig()
    CONTEXT_WINDOW_SIZE: int = 1024
    SEMANTIC_RATE_HZ: float = 49.9
    SEMANTIC_VOCAB_SIZE: int = 10_000
    CODEBOOK_SIZE: int = 1024
    N_COARSE_CODEBOOKS: int = 2
    N_FINE_CODEBOOKS: int = 8
    COARSE_RATE_HZ: int = 75
    SAMPLE_RATE: int = 24_000
    USE_SMALLER_MODELS: bool = False

    TEXT_ENCODING_OFFSET: int = 10_048
    SEMANTIC_PAD_TOKEN: int = 10_000
    TEXT_PAD_TOKEN: int = 129_595
    SEMANTIC_INFER_TOKEN: int = 129_599
    COARSE_SEMANTIC_PAD_TOKEN: int = 12_048
    COARSE_INFER_TOKEN: int = 12_050

    REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/"
    REMOTE_MODEL_PATHS: Dict = None
    LOCAL_MODEL_PATHS: Dict = None
    SMALL_REMOTE_MODEL_PATHS: Dict = None
    CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
    DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))

    def __post_init__(self):
        self.REMOTE_MODEL_PATHS = {
            "text": {
                "path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"),
                "checksum": "54afa89d65e318d4f5f80e8e8799026a",
            },
            "coarse": {
                "path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"),
                "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
            },
            "fine": {
                "path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"),
                "checksum": "59d184ed44e3650774a2f0503a48a97b",
            },
        }
        self.LOCAL_MODEL_PATHS = {
            "text": os.path.join(self.CACHE_DIR, "text_2.pt"),
            "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
            "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
            "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
            "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
        }
        self.SMALL_REMOTE_MODEL_PATHS = {
            "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},
            "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")},
            "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")},
        }
        self.sample_rate = self.SAMPLE_RATE  # pylint: disable=attribute-defined-outside-init
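
To make the __post_init__ behaviour concrete, a minimal sketch of instantiating the config; the printed values follow directly from the defaults above:

from TTS.tts.configs.bark_config import BarkConfig

config = BarkConfig()
print(config.sample_rate)                           # 24000, copied from SAMPLE_RATE in __post_init__
print(config.LOCAL_MODEL_PATHS["text"])             # <CACHE_DIR>/text_2.pt
print(config.REMOTE_MODEL_PATHS["coarse"]["path"])  # REMOTE_BASE_URL joined with coarse_2.pt
print(config.SMALL_REMOTE_MODEL_PATHS["fine"])      # {"path": REMOTE_BASE_URL + "fine.pt"} (no checksum)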

TTS/tts/layers/bark/__init__.py

Whitespace-only changes.

TTS/tts/layers/bark/hubert/__init__.py

Whitespace-only changes.

TTS/tts/layers/bark/hubert/hubert_manager.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer

import os.path
import shutil
import urllib.request

import huggingface_hub


class HubertManager:
    @staticmethod
    def make_sure_hubert_installed(
        download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
    ):
        if not os.path.isfile(model_path):
            print("Downloading HuBERT base model")
            urllib.request.urlretrieve(download_url, model_path)
            print("Downloaded HuBERT")
            return model_path
        return None

    @staticmethod
    def make_sure_tokenizer_installed(
        model: str = "quantifier_hubert_base_ls960_14.pth",
        repo: str = "GitMylo/bark-voice-cloning",
        model_path: str = "",
    ):
        model_dir = os.path.dirname(model_path)
        if not os.path.isfile(model_path):
            print("Downloading HuBERT custom tokenizer")
            huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
            shutil.move(os.path.join(model_dir, model), model_path)
            print("Downloaded tokenizer")
            return model_path
        return None
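
A hedged sketch of wiring these helpers to the cache paths from BarkConfig above; the module path (hubert_manager.py under TTS/tts/layers/bark/hubert/) matches this commit's layout, but the pairing itself is illustrative rather than a workflow mandated by the diff:

import os

from TTS.tts.configs.bark_config import BarkConfig
from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager

config = BarkConfig()
os.makedirs(config.CACHE_DIR, exist_ok=True)  # make sure the cache directory exists first

# Fetch the HuBERT base checkpoint and the custom quantizer tokenizer if they are missing.
HubertManager.make_sure_hubert_installed(model_path=config.LOCAL_MODEL_PATHS["hubert"])
HubertManager.make_sure_tokenizer_installed(model_path=config.LOCAL_MODEL_PATHS["hubert_tokenizer"])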

TTS/tts/layers/bark/hubert/kmeans_hubert.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
"""
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""

# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py

import logging
from pathlib import Path

import fairseq
import torch
from einops import pack, unpack
from torch import nn
from torchaudio.functional import resample

logging.root.setLevel(logging.ERROR)


def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor


def curtail_to_multiple(t, mult, from_left=False):
    data_len = t.shape[-1]
    rounded_seq_len = round_down_nearest_multiple(data_len, mult)
    seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None)
    return t[..., seq_slice]


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class CustomHubert(nn.Module):
    """
    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own
    """

    def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        if device is not None:
            self.to(device)

        model_path = Path(checkpoint_path)

        assert model_path.exists(), f"path {checkpoint_path} does not exist"

        checkpoint = torch.load(checkpoint_path)
        load_model_input = {checkpoint_path: checkpoint}
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        if device is not None:
            model[0].to(device)

        self.model = model[0]
        self.model.eval()

    @property
    def groups(self):
        return 1

    @torch.no_grad()
    def forward(self, wav_input, flatten=True, input_sample_hz=None):
        device = wav_input.device

        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
            output_layer=self.output_layer,
        )

        embed, packed_shape = pack([embed["x"]], "* d")

        # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)  # .long()

        if flatten:
            return codebook_indices

        (codebook_indices,) = unpack(codebook_indices, packed_shape, "*")
        return codebook_indices
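
Finally, a hedged sketch of extracting HuBERT features from a prompt waveform with CustomHubert, reusing the checkpoint downloaded above; the module path and the input file are assumptions, and since k-means was removed here the output is continuous features, with quantization into semantic tokens handled elsewhere in the Bark pipeline:

import torchaudio

from TTS.tts.configs.bark_config import BarkConfig
from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert  # assumed module path

config = BarkConfig()
hubert = CustomHubert(checkpoint_path=config.LOCAL_MODEL_PATHS["hubert"])

wav, sr = torchaudio.load("speaker_prompt.wav")  # illustrative input clip
wav = wav.mean(dim=0, keepdim=True)              # downmix to mono, shape (1, num_samples)
features = hubert(wav, input_sample_hz=sr)       # resampled to 16 kHz, layer-9 features, flattened
print(features.shape)                            # (num_frames, feature_dim) for a single clip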
