
Commit 275eb3b

Merge pull request #69 from rmusser01/dev
Bugfix for copying prompts
2 parents b7df503 + 5238f61 commit 275eb3b

7 files changed: +2217 -26 lines changed

tldw_chatbook/Embeddings/Chroma_Lib.py

Lines changed: 968 additions & 0 deletions
Large diffs are not rendered by default.
Embeddings_Lib.py

Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
# Embeddings_Lib.py
#
from __future__ import annotations
#
"""Adaptive embedding factory for desktop / TUI apps – asynchronous‑ready, logging‑aware.

This version is a thread-safe, robust implementation that addresses race conditions
and resource management issues found in previous iterations.

Key Features & Fixes in this Version
────────────────────────────────────
• **Thread-Safe & Race-Condition-Free**: The core embedding logic is now fully
  thread-safe, preventing race conditions where a model could be evicted by one
  thread while being used by another.
• **Correct Resource Management**: Eviction now occurs *before* loading a new
  model, ensuring the `max_cached` limit is never exceeded, preventing
  potential out-of-memory errors.
• **Improved Typing**: The internal cache and Pydantic configurations use
  `TypedDict` and discriminated unions for better static analysis and maintainability.
• **Async Facade**: `await factory.async_embed(...)` uses `asyncio.to_thread`
  so the UI never blocks.
• **Structured Logging**: Provides insight into cache hits, loads, and evictions.
• **Pluggable & High-Quality Pooling**: Defaults to a proper masked mean pooling
  strategy with L2 normalization for superior embedding quality.
• **Prefetch / Warm‑up**: `factory.prefetch(...)` downloads weights ahead of time.
"""
#
# Imports
import asyncio
import logging
import random
import threading
import time
from collections import OrderedDict
from typing import (Any, Annotated, Callable, Dict, List, Literal, Optional,
                    Protocol, TypedDict, Union)
#
# Third-Party Libraries
import numpy as np
import requests
import torch
from pydantic import BaseModel, Field, field_validator
from torch import Tensor
from torch.nn.functional import normalize
from transformers import AutoModel, AutoTokenizer
#
# Local Imports
#
########################################################################################################################
#
__all__ = ["EmbeddingFactory", "EmbeddingConfigSchema"]
#
LOGGER = logging.getLogger("embeddings_lib")
LOGGER.addHandler(logging.NullHandler())
#
###############################################################################
# Configuration schema (with Discriminated Union)
###############################################################################

PoolingFn = Callable[[Tensor, Tensor], Tensor]


def _masked_mean(last_hidden: Tensor, attn: Tensor) -> Tensor:
    """Default pooling: mean of vectors where attention_mask is 1."""
    mask = attn.unsqueeze(-1).type_as(last_hidden)
    summed = (last_hidden * mask).sum(dim=1)
    lengths = mask.sum(dim=1).clamp(min=1e-9)
    avg = summed / lengths
    return normalize(avg, p=2, dim=1)
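
# Shape note: given `last_hidden` of shape (batch, seq_len, hidden) and `attn` of
# shape (batch, seq_len), `_masked_mean` returns an L2-normalised (batch, hidden)
# tensor; padding positions (attention_mask == 0) do not contribute to the mean.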


class HFModelCfg(BaseModel):
    provider: Literal["huggingface"] = "huggingface"
    model_name_or_path: str
    trust_remote_code: bool = False
    max_length: int = 512
    device: Optional[str] = None
    batch_size: int = 32
    pooling: Optional[PoolingFn] = None  # default: masked mean
    dimension: Optional[int] = None


class OpenAICfg(BaseModel):
    provider: Literal["openai"] = "openai"
    model_name_or_path: str = "text-embedding-3-small"
    api_key: Optional[str] = Field(default=None, repr=False)
    dimension: Optional[int] = None

    @field_validator("api_key", mode="before")
    def _default_api_key(cls, v: str | None) -> str:
        if v:
            return v
        from os import getenv

        env = getenv("OPENAI_API_KEY")
        if not env:
            raise ValueError("OPENAI_API_KEY env-var missing and api_key not set")
        return env


# A discriminated union lets Pydantic and type checkers infer the correct model type
ModelCfg = Annotated[
    Union[HFModelCfg, OpenAICfg],
    Field(discriminator="provider")
]


class EmbeddingConfigSchema(BaseModel):
    default_model_id: Optional[str] = None
    models: Dict[str, ModelCfg]
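
# Example (illustrative): the "provider" field drives the discriminated union, so
# the entry below validates as an HFModelCfg; a provider of "openai" would select
# OpenAICfg instead. The "e5-small" key and model name are placeholders.
#
#     EmbeddingConfigSchema(
#         models={"e5-small": {"provider": "huggingface",
#                              "model_name_or_path": "intfloat/e5-small-v2"}}
#     )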


###############################################################################
# Provider helpers
###############################################################################

class EmbedFn(Protocol):
    """A protocol defining a callable for embedding texts.

    This ensures that any function used as an embedding function adheres to this
    specific signature, including the keyword-only `as_list` argument.
    """
    def __call__(
        self, texts: List[str], *, as_list: bool = False
    ) -> Union[np.ndarray, List[List[float]]]:
        ...


class CacheRecord(TypedDict):
    """Strongly-typed structure for a cache entry."""
    embed: EmbedFn
    close: Optional[Callable[[], None]]
    last: float


class _HuggingFaceEmbedder:
    """Wraps HF model/tokenizer; exposes poolable, dtype/device-aware embedding."""

    def __init__(self, cfg: HFModelCfg):
        try:
            self._tok = AutoTokenizer.from_pretrained(
                cfg.model_name_or_path, trust_remote_code=cfg.trust_remote_code
            )
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            self._model = AutoModel.from_pretrained(
                cfg.model_name_or_path,
                torch_dtype=dtype,
                trust_remote_code=cfg.trust_remote_code,
            )
        # --- [FIX] Added robust error handling for model loading ---
        except (OSError, requests.exceptions.RequestException) as e:
            raise IOError(
                f"Failed to download or load model '{cfg.model_name_or_path}'. "
                "Check the model name and your network connection."
            ) from e

        self._device = torch.device(
            cfg.device if cfg.device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self._model.to(self._device).eval()
        self._max_len = cfg.max_length
        self._batch_size = cfg.batch_size
        self._pool = cfg.pooling or _masked_mean

    @torch.inference_mode()
    def _forward(self, inp: Dict[str, Tensor]) -> Tensor:
        out = self._model(**inp).last_hidden_state
        return self._pool(out, inp["attention_mask"])

    def embed(self, texts: List[str], *, as_list: bool = False) -> np.ndarray | List[List[float]]:
        vecs: List[Tensor] = []
        for i in range(0, len(texts), self._batch_size):
            batch = texts[i: i + self._batch_size]
            tok = self._tok(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self._max_len,
            )
            tok = {k: v.to(self._device) for k, v in tok.items()}
            # --- [FIX] Performance: keep tensors on GPU during loop ---
            vecs.append(self._forward(tok))

        # --- [FIX] Performance: concatenate on GPU, then move to CPU once ---
        joined = torch.cat(vecs, dim=0).float().cpu().numpy()
        return joined.tolist() if as_list else joined

    def close(self) -> None:
        del self._model, self._tok
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# --------------------------------------------------------------------------

_BACKOFF: tuple[float, ...] = (1, 2, 4, 8)


def _openai_embedder(model: str, api_key: str) -> EmbedFn:
    session = requests.Session()
    session.headers.update(
        {"Authorization": f"Bearer {api_key}", "User-Agent": "EmbeddingsLib/4.0"}
    )

    def _embed(texts: List[str], *, as_list: bool = False) -> np.ndarray | List[List[float]]:
        payload = {"input": texts, "model": model}
        for attempt, wait in enumerate(_BACKOFF, 1):
            t0 = time.perf_counter()
            try:
                resp = session.post("https://api.openai.com/v1/embeddings", json=payload, timeout=30)
                resp.raise_for_status()
                data = resp.json()["data"]
                arr = np.asarray([d["embedding"] for d in data], dtype=np.float32)
                latency = time.perf_counter() - t0
                LOGGER.debug("openai_embed[%s] %d texts in %.3fs", model, len(texts), latency)
                return arr.tolist() if as_list else arr
            except requests.RequestException as exc:
                if attempt == len(_BACKOFF):
                    LOGGER.error("openai_embed failed after %d retries: %s", attempt, exc)
                    raise
                LOGGER.warning("openai_embed retry %d/%d after %s: %s", attempt, len(_BACKOFF), wait, exc)
                time.sleep(wait + random.random())
        raise RuntimeError("Exhausted retries in OpenAI embedder")  # Should be unreachable

    return _embed


###############################################################################
# Factory
###############################################################################

class EmbeddingFactory:
    """Thread‑safe LRU/idle cache with sync & async embedding methods."""

    def __init__(
        self,
        cfg: Dict[str, Any] | EmbeddingConfigSchema,
        *,
        max_cached: int = 2,
        idle_seconds: int = 900,
        allow_dynamic_hf: bool = True,
    ) -> None:
        self._cfg = cfg if isinstance(cfg, EmbeddingConfigSchema) else EmbeddingConfigSchema(**cfg)
        self._max_cached = max_cached
        self._idle = idle_seconds
        self._allow_dynamic_hf = allow_dynamic_hf
        self._cache: "OrderedDict[str, CacheRecord]" = OrderedDict()
        self._lock = threading.Lock()
        LOGGER.debug("factory initialised max_cached=%d idle=%ds", max_cached, idle_seconds)

    def _get_spec(self, model_id: str) -> ModelCfg:
        try:
            return self._cfg.models[model_id]
        except KeyError:
            if self._allow_dynamic_hf:
                LOGGER.info("dynamic HF model %s", model_id)
                return HFModelCfg(model_name_or_path=model_id, trust_remote_code=False)
            raise

    @property
    def config(self) -> EmbeddingConfigSchema:
        return self._cfg

    def _build(self, model_id: str) -> CacheRecord:
        spec = self._get_spec(model_id)
        t0 = time.perf_counter()
        if spec.provider == "huggingface":
            hf = _HuggingFaceEmbedder(spec)
            rec = CacheRecord(embed=hf.embed, close=hf.close, last=time.monotonic())
        elif spec.provider == "openai":
            fn = _openai_embedder(spec.model_name_or_path, spec.api_key)
            rec = CacheRecord(embed=fn, close=None, last=time.monotonic())
        else:
            raise ValueError(f"Unsupported provider: {spec.provider}")
        LOGGER.debug("load %s in %.2fs", model_id, time.perf_counter() - t0)
        return rec

    def embed(
        self, texts: List[str], *, model_id: Optional[str] = None, as_list: bool = False
    ):
        model_id_to_use = model_id or self._cfg.default_model_id
        if not model_id_to_use:
            raise ValueError("No model_id provided and no default_model_id is set.")

        if not texts:
            return [] if as_list else np.empty((0, 0), dtype=np.float32)

        # --- The lock must be held during the embedding call to prevent use-after-free ---
        with self._lock:
            # First, check for idle models to evict.
            now = time.monotonic()
            for mid, rec in list(self._cache.items()):
                if now - rec["last"] > self._idle:
                    LOGGER.debug("idle evict %s", mid)
                    self._cache.pop(mid)
                    if rec["close"]:
                        rec["close"]()

            # Now, get the model, building it if it doesn't exist.
            rec = self._cache.get(model_id_to_use)
            if rec is None:
                # If we need to load a model, make space for it *first*.
                while len(self._cache) >= self._max_cached:
                    lru_mid, lru_rec = self._cache.popitem(last=False)
                    LOGGER.debug("LRU evict %s to make space", lru_mid)
                    if lru_rec["close"]:
                        lru_rec["close"]()
                rec = self._build(model_id_to_use)
                self._cache[model_id_to_use] = rec

            # Mark as most recently used and get the function to call.
            rec["last"] = time.monotonic()
            self._cache.move_to_end(model_id_to_use)
            embed_fn = rec["embed"]

            # The embedding call itself is now inside the lock. This serializes
            # all embedding calls, but guarantees that the model cannot be
            # evicted by another thread while it is in use.
            t0 = time.perf_counter()
            result = embed_fn(texts, as_list=as_list)
            LOGGER.debug("embed %s %d texts in %.3fs", model_id_to_use, len(texts), time.perf_counter() - t0)
            return result
        # The lock is released only after the work is complete.

    async def async_embed(
        self, texts: List[str], *, model_id: Optional[str] = None, as_list: bool = False
    ):
        """Non-blocking version of `embed` for use in async contexts."""
        return await asyncio.to_thread(self.embed, texts, model_id=model_id, as_list=as_list)
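
    # Illustrative call site (e.g. from an async UI handler); the "e5-small"
    # model id is a placeholder for an entry defined in the factory's config:
    #
    #     vec = await factory.async_embed(["some text"], model_id="e5-small")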

    def embed_one(
        self, text: str, *, model_id: Optional[str] = None, as_list: bool = False
    ):
        vecs = self.embed([text], model_id=model_id, as_list=as_list)
        return vecs[0] if as_list else vecs.squeeze(0)

    async def async_embed_one(
        self, text: str, *, model_id: Optional[str] = None, as_list: bool = False
    ):
        vecs = await self.async_embed([text], model_id=model_id, as_list=as_list)
        return vecs[0] if as_list else vecs.squeeze(0)

    def prefetch(self, model_ids: List[str]):
        """Download / load given model ids in advance (bypasses eviction)."""
        for mid in model_ids:
            with self._lock:
                if mid in self._cache:
                    continue
                # Note: This can temporarily exceed max_cached, which is acceptable for a startup operation.
                self._cache[mid] = self._build(mid)
                LOGGER.info("prefetched %s", mid)

    def close(self) -> None:
        """Close all models and clear the cache."""
        with self._lock:
            for mid, rec in self._cache.items():
                if rec["close"]:
                    rec["close"]()
            self._cache.clear()
        LOGGER.debug("factory closed")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

#
# End of Embeddings_Lib.py
########################################################################################################################
