Skip to content

Commit c883a75

Browse files
authored
fix: missing vocab_config in from_session (#144)
1 parent 1c0aff5 commit c883a75

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/kokoro_onnx/__init__.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def __init__(
3030
model_path: str,
3131
voices_path: str,
3232
espeak_config: EspeakConfig | None = None,
33-
vocab_config: dict | str = None,
33+
vocab_config: dict | str | None = None,
3434
):
3535
# Show useful information for bug reports
3636
log.debug(
@@ -56,15 +56,7 @@ def __init__(
5656
self.sess = rt.InferenceSession(model_path, providers=providers)
5757
self.voices: np.ndarray = np.load(voices_path)
5858

59-
vocab = None
60-
61-
if isinstance(vocab_config, str):
62-
with open(vocab_config, "r", encoding="utf-8") as fp:
63-
config = json.load(fp)
64-
vocab = config["vocab"]
65-
elif isinstance(vocab, dict):
66-
vocab = vocab["vocab"]
67-
59+
vocab = self._load_vocab(vocab_config)
6860
self.tokenizer = Tokenizer(espeak_config, vocab=vocab)
6961

7062
@classmethod
@@ -73,15 +65,36 @@ def from_session(
7365
session: rt.InferenceSession,
7466
voices_path: str,
7567
espeak_config: EspeakConfig | None = None,
68+
vocab_config: dict | str | None = None,
7669
):
7770
instance = cls.__new__(cls)
7871
instance.sess = session
7972
instance.config = KoKoroConfig(session._model_path, voices_path, espeak_config)
8073
instance.config.validate()
8174
instance.voices = np.load(voices_path)
82-
instance.tokenizer = Tokenizer(espeak_config)
75+
76+
vocab = instance._load_vocab(vocab_config)
77+
instance.tokenizer = Tokenizer(espeak_config, vocab=vocab)
8378
return instance
8479

80+
def _load_vocab(self, vocab_config: dict | str | None) -> dict:
81+
"""Load vocabulary from config file or dictionary.
82+
83+
Args:
84+
vocab_config: Path to vocab config file or dictionary containing vocab.
85+
86+
Returns:
87+
Loaded vocabulary dictionary or empty dictionary if no config provided.
88+
"""
89+
90+
if isinstance(vocab_config, str):
91+
with open(vocab_config, "r", encoding="utf-8") as fp:
92+
config = json.load(fp)
93+
return config["vocab"]
94+
if isinstance(vocab_config, dict):
95+
return vocab_config["vocab"]
96+
return {}
97+
8598
def _create_audio(
8699
self, phonemes: str, voice: NDArray[np.float32], speed: float
87100
) -> tuple[NDArray[np.float32], int]:

0 commit comments

Comments (0)