# Embeddings_Lib.py
#
from __future__ import annotations
#
"""Adaptive embedding factory for desktop / TUI apps – asynchronous‑ready, logging‑aware.

This version is a thread-safe, robust implementation that addresses race conditions
and resource management issues found in previous iterations.

Key Features & Fixes in this Version
────────────────────────────────────
• **Thread-Safe & Race-Condition-Free**: The core embedding logic is now fully
  thread-safe, preventing race conditions where a model could be evicted by one
  thread while being used by another.
• **Correct Resource Management**: Eviction now occurs *before* loading a new
  model, ensuring the `max_cached` limit is never exceeded and preventing
  potential out-of-memory errors.
• **Improved Typing**: The internal cache and Pydantic configurations use
  `TypedDict` and discriminated unions for better static analysis and maintainability.
• **Async Facade**: `await factory.async_embed(...)` uses `asyncio.to_thread`
  so the UI never blocks.
• **Structured Logging**: Provides insight into cache hits, loads, and evictions.
• **Pluggable & High-Quality Pooling**: Defaults to a proper masked mean pooling
  strategy with L2 normalization for superior embedding quality.
• **Prefetch / Warm‑up**: `factory.prefetch(...)` downloads and loads weights
  ahead of time, e.g. at application start-up.

A minimal, illustrative usage sketch is included at the end of this file
(under `if __name__ == "__main__"`).
"""
#
# Imports
import asyncio
import logging
import random
import threading
import time
from collections import OrderedDict
from typing import (Any, Annotated, Callable, Dict, List, Literal, Optional,
                    Protocol, TypedDict, Union)
#
# Third-Party Libraries
import numpy as np
import requests
import torch
from pydantic import BaseModel, Field, field_validator
from torch import Tensor
from torch.nn.functional import normalize
from transformers import AutoModel, AutoTokenizer
#
# Local Imports
#
########################################################################################################################
#
__all__ = ["EmbeddingFactory", "EmbeddingConfigSchema"]
#
LOGGER = logging.getLogger("embeddings_lib")
LOGGER.addHandler(logging.NullHandler())
#
###############################################################################
# Configuration schema (with Discriminated Union)
###############################################################################

PoolingFn = Callable[[Tensor, Tensor], Tensor]


def _masked_mean(last_hidden: Tensor, attn: Tensor) -> Tensor:
    """Default pooling: mean of vectors where attention_mask is 1."""
    mask = attn.unsqueeze(-1).type_as(last_hidden)
    summed = (last_hidden * mask).sum(dim=1)
    lengths = mask.sum(dim=1).clamp(min=1e-9)
    avg = summed / lengths
    return normalize(avg, p=2, dim=1)


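# The pooling hook is pluggable: any callable matching ``PoolingFn`` can be supplied
# via ``HFModelCfg(pooling=...)``. As an illustrative alternative (a sketch, not a
# built-in default), CLS-token pooling simply takes the first token's hidden state:
def _cls_pool(last_hidden: Tensor, attn: Tensor) -> Tensor:
    """Example pooling: use the [CLS] / first-token vector, L2-normalised.

    The attention mask is accepted only to satisfy the ``PoolingFn`` signature.
    """
    return normalize(last_hidden[:, 0], p=2, dim=1)

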
class HFModelCfg(BaseModel):
    provider: Literal["huggingface"] = "huggingface"
    model_name_or_path: str
    trust_remote_code: bool = False
    max_length: int = 512
    device: Optional[str] = None
    batch_size: int = 32
    pooling: Optional[PoolingFn] = None  # default: masked mean
    dimension: Optional[int] = None


class OpenAICfg(BaseModel):
    provider: Literal["openai"] = "openai"
    model_name_or_path: str = "text-embedding-3-small"
    api_key: Optional[str] = Field(default=None, repr=False)
    dimension: Optional[int] = None

    @field_validator("api_key", mode="before")
    @classmethod
    def _default_api_key(cls, v: str | None) -> str:
        if v:
            return v
        from os import getenv

        env = getenv("OPENAI_API_KEY")
        if not env:
            raise ValueError("OPENAI_API_KEY env-var missing and api_key not set")
        return env


# A discriminated union lets Pydantic and type checkers infer the correct config type
# from the `provider` field.
ModelCfg = Annotated[
    Union[HFModelCfg, OpenAICfg],
    Field(discriminator="provider")
]


class EmbeddingConfigSchema(BaseModel):
    """Top-level configuration: a map of model ids to provider-specific configs."""

    default_model_id: Optional[str] = None
    models: Dict[str, ModelCfg]


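# Example (illustrative only; the ids and model names below are placeholders):
#
#   cfg = EmbeddingConfigSchema(
#       default_model_id="local-minilm",
#       models={
#           "local-minilm": {"provider": "huggingface",
#                            "model_name_or_path": "sentence-transformers/all-MiniLM-L6-v2"},
#           "openai-small": {"provider": "openai",
#                            "model_name_or_path": "text-embedding-3-small"},
#       },
#   )
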
###############################################################################
# Provider helpers
###############################################################################

class EmbedFn(Protocol):
    """A protocol defining a callable for embedding texts.

    This ensures that any function used as an embedding function adheres to this
    specific signature, including the keyword-only `as_list` argument.
    """
    def __call__(
        self, texts: List[str], *, as_list: bool = False
    ) -> Union[np.ndarray, List[List[float]]]:
        ...


class CacheRecord(TypedDict):
    """Strongly-typed structure for a cache entry."""
    embed: EmbedFn
    close: Optional[Callable[[], None]]
    last: float


class _HuggingFaceEmbedder:
    """Wraps HF model/tokenizer; exposes poolable, dtype/device-aware embedding."""

    def __init__(self, cfg: HFModelCfg):
        self._device = torch.device(
            cfg.device if cfg.device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Use fp16 only when the model will actually run on a CUDA device.
        dtype = torch.float16 if self._device.type == "cuda" else torch.float32
        try:
            self._tok = AutoTokenizer.from_pretrained(
                cfg.model_name_or_path, trust_remote_code=cfg.trust_remote_code
            )
            self._model = AutoModel.from_pretrained(
                cfg.model_name_or_path,
                torch_dtype=dtype,
                trust_remote_code=cfg.trust_remote_code,
            )
        # --- [FIX] Added robust error handling for model loading ---
        except (OSError, requests.exceptions.RequestException) as e:
            raise IOError(
                f"Failed to download or load model '{cfg.model_name_or_path}'. "
                "Check the model name and your network connection."
            ) from e

        self._model.to(self._device).eval()
        self._max_len = cfg.max_length
        self._batch_size = cfg.batch_size
        self._pool = cfg.pooling or _masked_mean

    @torch.inference_mode()
    def _forward(self, inp: Dict[str, Tensor]) -> Tensor:
        out = self._model(**inp).last_hidden_state
        return self._pool(out, inp["attention_mask"])

    def embed(self, texts: List[str], *, as_list: bool = False) -> np.ndarray | List[List[float]]:
        vecs: List[Tensor] = []
        for i in range(0, len(texts), self._batch_size):
            batch = texts[i: i + self._batch_size]
            tok = self._tok(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self._max_len,
            )
            tok = {k: v.to(self._device) for k, v in tok.items()}
            # --- [FIX] Performance: keep tensors on GPU during loop ---
            vecs.append(self._forward(tok))

        # --- [FIX] Performance: concatenate on GPU, then move to CPU once ---
        joined = torch.cat(vecs, dim=0).float().cpu().numpy()
        return joined.tolist() if as_list else joined

    def close(self) -> None:
        del self._model, self._tok
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# --------------------------------------------------------------------------

_BACKOFF: tuple[float, ...] = (1, 2, 4, 8)


def _openai_embedder(model: str, api_key: str) -> EmbedFn:
    session = requests.Session()
    session.headers.update(
        {"Authorization": f"Bearer {api_key}", "User-Agent": "EmbeddingsLib/4.0"}
    )

    def _embed(texts: List[str], *, as_list: bool = False) -> np.ndarray | List[List[float]]:
        payload = {"input": texts, "model": model}
        for attempt, wait in enumerate(_BACKOFF, 1):
            t0 = time.perf_counter()
            try:
                resp = session.post("https://api.openai.com/v1/embeddings", json=payload, timeout=30)
                resp.raise_for_status()
                data = resp.json()["data"]
                arr = np.asarray([d["embedding"] for d in data], dtype=np.float32)
                latency = time.perf_counter() - t0
                LOGGER.debug("openai_embed[%s] %d texts in %.3fs", model, len(texts), latency)
                return arr.tolist() if as_list else arr
            except requests.RequestException as exc:
                if attempt == len(_BACKOFF):
                    LOGGER.error("openai_embed failed after %d retries: %s", attempt, exc)
                    raise
                LOGGER.warning("openai_embed retry %d/%d after %s: %s", attempt, len(_BACKOFF), wait, exc)
                time.sleep(wait + random.random())
        raise RuntimeError("Exhausted retries in OpenAI embedder")  # Should be unreachable

    return _embed


###############################################################################
# Factory
###############################################################################

class EmbeddingFactory:
    """Thread‑safe LRU/idle cache with sync & async embedding methods."""

    def __init__(
        self,
        cfg: Dict[str, Any] | EmbeddingConfigSchema,
        *,
        max_cached: int = 2,
        idle_seconds: int = 900,
        allow_dynamic_hf: bool = True,
    ) -> None:
        self._cfg = cfg if isinstance(cfg, EmbeddingConfigSchema) else EmbeddingConfigSchema(**cfg)
        self._max_cached = max_cached
        self._idle = idle_seconds
        self._allow_dynamic_hf = allow_dynamic_hf
        self._cache: "OrderedDict[str, CacheRecord]" = OrderedDict()
        self._lock = threading.Lock()
        LOGGER.debug("factory initialised max_cached=%d idle=%ds", max_cached, idle_seconds)

    def _get_spec(self, model_id: str) -> ModelCfg:
        try:
            return self._cfg.models[model_id]
        except KeyError:
            if self._allow_dynamic_hf:
                LOGGER.info("dynamic HF model %s", model_id)
                return HFModelCfg(model_name_or_path=model_id, trust_remote_code=False)
            raise

    @property
    def config(self) -> EmbeddingConfigSchema:
        return self._cfg

    def _build(self, model_id: str) -> CacheRecord:
        spec = self._get_spec(model_id)
        t0 = time.perf_counter()
        if spec.provider == "huggingface":
            hf = _HuggingFaceEmbedder(spec)
            rec = CacheRecord(embed=hf.embed, close=hf.close, last=time.monotonic())
        elif spec.provider == "openai":
            fn = _openai_embedder(spec.model_name_or_path, spec.api_key)
            rec = CacheRecord(embed=fn, close=None, last=time.monotonic())
        else:
            raise ValueError(f"Unsupported provider: {spec.provider}")
        LOGGER.debug("load %s in %.2fs", model_id, time.perf_counter() - t0)
        return rec

    def embed(
        self, texts: List[str], *, model_id: Optional[str] = None, as_list: bool = False
    ) -> Union[np.ndarray, List[List[float]]]:
        model_id_to_use = model_id or self._cfg.default_model_id
        if not model_id_to_use:
            raise ValueError("No model_id provided and no default_model_id is set.")

        if not texts:
            return [] if as_list else np.empty((0, 0), dtype=np.float32)

        # --- The lock must be held during the embedding call to prevent use-after-free ---
        with self._lock:
            # First, check for idle models to evict.
            now = time.monotonic()
            for mid, rec in list(self._cache.items()):
                if now - rec["last"] > self._idle:
                    LOGGER.debug("idle evict %s", mid)
                    self._cache.pop(mid)
                    if rec["close"]:
                        rec["close"]()

            # Now, get the model, building it if it doesn't exist.
            rec = self._cache.get(model_id_to_use)
            if rec is None:
                # If we need to load a model, make space for it *first*.
                while len(self._cache) >= self._max_cached:
                    lru_mid, lru_rec = self._cache.popitem(last=False)
                    LOGGER.debug("LRU evict %s to make space", lru_mid)
                    if lru_rec["close"]:
                        lru_rec["close"]()
                rec = self._build(model_id_to_use)
                self._cache[model_id_to_use] = rec

            # Mark as most recently used and get the function to call.
            rec["last"] = time.monotonic()
            self._cache.move_to_end(model_id_to_use)
            embed_fn = rec["embed"]

            # The embedding call itself is inside the lock. This serializes
            # all embedding calls, but guarantees that the model cannot be
            # evicted by another thread while it is in use.
            t0 = time.perf_counter()
            result = embed_fn(texts, as_list=as_list)
            LOGGER.debug("embed %s %d texts in %.3fs", model_id_to_use, len(texts), time.perf_counter() - t0)
            return result
        # The lock is released only after the work is complete.

    async def async_embed(
        self, texts: List[str], *, model_id: Optional[str] = None, as_list: bool = False
    ) -> Union[np.ndarray, List[List[float]]]:
        """Non-blocking version of `embed` for async contexts (uses `asyncio.to_thread`, Python 3.9+)."""
        return await asyncio.to_thread(self.embed, texts, model_id=model_id, as_list=as_list)

    def embed_one(
        self, text: str, *, model_id: Optional[str] = None, as_list: bool = False
    ) -> Union[np.ndarray, List[float]]:
        vecs = self.embed([text], model_id=model_id, as_list=as_list)
        return vecs[0] if as_list else vecs.squeeze(0)

    async def async_embed_one(
        self, text: str, *, model_id: Optional[str] = None, as_list: bool = False
    ) -> Union[np.ndarray, List[float]]:
        vecs = await self.async_embed([text], model_id=model_id, as_list=as_list)
        return vecs[0] if as_list else vecs.squeeze(0)

    def prefetch(self, model_ids: List[str]) -> None:
        """Download / load the given model ids in advance (bypasses LRU eviction)."""
        for mid in model_ids:
            with self._lock:
                if mid in self._cache:
                    continue
                # Note: This can temporarily exceed max_cached, which is acceptable for a startup operation.
                self._cache[mid] = self._build(mid)
                LOGGER.info("prefetched %s", mid)

    def close(self) -> None:
        """Close all models and clear the cache."""
        with self._lock:
            for mid, rec in self._cache.items():
                if rec["close"]:
                    rec["close"]()
            self._cache.clear()
            LOGGER.debug("factory closed")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

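
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The model id and name below are
# placeholders, not project defaults; running this will download model weights.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    demo_cfg = {
        "default_model_id": "minilm",
        "models": {
            "minilm": {
                "provider": "huggingface",
                "model_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
            },
        },
    }

    with EmbeddingFactory(demo_cfg, max_cached=1) as factory:
        # Optional warm-up so the first embed() call doesn't pay the load cost.
        factory.prefetch(["minilm"])

        # Synchronous call: returns an (n_texts, dim) float32 numpy array.
        vecs = factory.embed(["hello world", "embeddings are fun"])
        print(vecs.shape)

        # Async facade: the same work, dispatched to a worker thread.
        async def _demo() -> None:
            one = await factory.async_embed_one("hello again")
            print(one.shape)

        asyncio.run(_demo())
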
#
# End of Embeddings_Lib.py
########################################################################################################################