Skip to content

Commit a2c03c0

Browse files
committed
auto embeddings
- HF embed by default
- node imports fix
1 parent 8dedda9 commit a2c03c0

File tree

7 files changed

+138
-63
lines changed

7 files changed

+138
-63
lines changed

.vscode/launch.json

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"version": "0.2.0",
33
"configurations": [
4+
45
{
56
"name": "Debug run_ui.py",
67
"type": "debugpy",
@@ -10,15 +11,6 @@
1011
"justMyCode": false,
1112
"args": ["--development=true", "-Xfrozen_modules=off"]
1213
},
13-
{
14-
"name": "Debug run_cli.py",
15-
"type": "debugpy",
16-
"request": "launch",
17-
"program": "./run_cli.py",
18-
"console": "integratedTerminal",
19-
"justMyCode": false,
20-
"args": ["--development=true", "-Xfrozen_modules=off"]
21-
},
2214
{
2315
"name": "Debug current file",
2416
"type": "debugpy",

agent.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,21 @@
22
from collections import OrderedDict
33
from dataclasses import dataclass, field
44
from datetime import datetime
5-
import time, importlib, inspect, os, json
6-
import token
75
from typing import Any, Awaitable, Coroutine, Optional, Dict, TypedDict
86
import uuid
97
import models
108

11-
from langchain_core.prompt_values import ChatPromptValue
129
from python.helpers import extract_tools, rate_limiter, files, errors, history, tokens
1310
from python.helpers.print_style import PrintStyle
1411
from langchain_core.prompts import (
1512
ChatPromptTemplate,
16-
MessagesPlaceholder,
17-
HumanMessagePromptTemplate,
18-
StringPromptTemplate,
1913
)
20-
from langchain_core.prompts.image import ImagePromptTemplate
2114
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
22-
from langchain_core.language_models.chat_models import BaseChatModel
23-
from langchain_core.language_models.llms import BaseLLM
24-
from langchain_core.embeddings import Embeddings
15+
2516
import python.helpers.log as Log
2617
from python.helpers.dirty_json import DirtyJson
2718
from python.helpers.defer import DeferredTask
2819
from typing import Callable
29-
from python.helpers.history import OutputMessage
3020
from python.helpers.localization import Localization
3121

3222

docker/run/fs/exe/node_eval.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ const Module = require('module');
77
// Enhance `require` to search CWD first, then globally
88
function customRequire(moduleName) {
99
try {
10-
// Try resolving from CWD's node_modules
11-
const cwdPath = path.resolve(process.cwd(), 'node_modules', moduleName);
10+
// Try resolving from CWD's node_modules using Node's require.resolve
11+
const cwdPath = require.resolve(moduleName, { paths: [path.join(process.cwd(), 'node_modules')] });
12+
// console.log("resolved path:", cwdPath);
1213
return require(cwdPath);
1314
} catch (cwdErr) {
1415
try {

preload.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from python.helpers import runtime, whisper, settings
33
from python.helpers.print_style import PrintStyle
4+
import models
45

56
PrintStyle().print("Running preload...")
67
runtime.initialize()
@@ -10,12 +11,31 @@ async def preload():
1011
try:
1112
set = settings.get_default_settings()
1213

14+
# preload whisper model
15+
async def preload_whisper():
16+
try:
17+
return await whisper.preload(set["stt_model_size"])
18+
except Exception as e:
19+
PrintStyle().error(f"Error in preload_whisper: {e}")
20+
21+
# preload embedding model
22+
async def preload_embedding():
23+
if set["embed_model_provider"] == models.ModelProvider.HUGGINGFACE.name:
24+
try:
25+
emb_mod = models.get_huggingface_embedding(set["embed_model_name"])
26+
emb_txt = await emb_mod.aembed_query("test")
27+
return emb_txt
28+
except Exception as e:
29+
PrintStyle().error(f"Error in preload_embedding: {e}")
30+
31+
1332
# async tasks to preload
14-
tasks = [whisper.preload(set["stt_model_size"])]
33+
tasks = [preload_whisper(), preload_embedding()]
1534

16-
return asyncio.gather(*tasks, return_exceptions=True)
35+
await asyncio.gather(*tasks, return_exceptions=True)
36+
PrintStyle().print("Preload completed")
1737
except Exception as e:
18-
PrintStyle().print(f"Error in preload: {e}")
38+
PrintStyle().error(f"Error in preload: {e}")
1939

2040

2141
# preload transcription model

python/helpers/files.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,8 @@ def move_file(relative_path: str, new_path: str):
285285
new_abs_path = get_abs_path(new_path)
286286
os.makedirs(os.path.dirname(new_abs_path), exist_ok=True)
287287
os.rename(abs_path, new_abs_path)
288+
289+
def safe_file_name(filename:str)-> str:
290+
# Replace any character that's not alphanumeric, dash, underscore, or dot with underscore
291+
import re
292+
return re.sub(r'[^a-zA-Z0-9-._]', '_', filename)

python/helpers/memory.py

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from python.helpers import knowledge_import
2424
from python.helpers.log import Log, LogItem
2525
from enum import Enum
26-
from agent import Agent
26+
from agent import Agent, ModelConfig
2727
import models
2828

2929

@@ -36,6 +36,9 @@ def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
3636
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
3737
return self.get_by_ids(ids)
3838

39+
def get_all_docs(self):
40+
return self.docstore._dict # type: ignore
41+
3942

4043
class Memory:
4144

@@ -55,14 +58,9 @@ async def get(agent: Agent):
5558
type="util",
5659
heading=f"Initializing VectorDB in '/{memory_subdir}'",
5760
)
58-
db = Memory.initialize(
61+
db, created = Memory.initialize(
5962
log_item,
60-
models.get_model(
61-
models.ModelType.EMBEDDING,
62-
agent.config.embeddings_model.provider,
63-
agent.config.embeddings_model.name,
64-
**agent.config.embeddings_model.kwargs,
65-
),
63+
agent.config.embeddings_model,
6664
memory_subdir,
6765
False,
6866
)
@@ -90,10 +88,10 @@ async def reload(agent: Agent):
9088
@staticmethod
9189
def initialize(
9290
log_item: LogItem | None,
93-
embeddings_model: Embeddings,
91+
model_config: ModelConfig,
9492
memory_subdir: str,
9593
in_memory=False,
96-
) -> MyFaiss:
94+
) -> tuple[MyFaiss, bool]:
9795

9896
PrintStyle.standard("Initializing VectorDB...")
9997

@@ -114,20 +112,26 @@ def initialize(
114112
os.makedirs(em_dir, exist_ok=True)
115113
store = LocalFileStore(em_dir)
116114

115+
embeddings_model = models.get_model(
116+
models.ModelType.EMBEDDING,
117+
model_config.provider,
118+
model_config.name,
119+
**model_config.kwargs,
120+
)
121+
embeddings_model_id = files.safe_file_name(
122+
model_config.provider.name + "_" + model_config.name
123+
)
124+
117125
# here we setup the embeddings model with the chosen cache storage
118126
embedder = CacheBackedEmbeddings.from_bytes_store(
119-
embeddings_model,
120-
store,
121-
namespace=getattr(
122-
embeddings_model,
123-
"model",
124-
getattr(embeddings_model, "model_name", "default"),
125-
),
127+
embeddings_model, store, namespace=embeddings_model_id
126128
)
127129

128-
# self.db = Chroma(
129-
# embedding_function=self.embedder,
130-
# persist_directory=db_dir)
130+
# initial DB and docs variables
131+
db: MyFaiss | None = None
132+
docs: dict[str, Document] | None = None
133+
134+
created = False
131135

132136
# if db folder exists and is not empty:
133137
if os.path.exists(db_dir) and files.exists(db_dir, "index.faiss"):
@@ -138,8 +142,27 @@ def initialize(
138142
distance_strategy=DistanceStrategy.COSINE,
139143
# normalize_L2=True,
140144
relevance_score_fn=Memory._cosine_normalizer,
141-
)
142-
else:
145+
) # type: ignore
146+
147+
# if there is a mismatch in embeddings used, re-index the whole DB
148+
emb_ok = False
149+
emb_set_file = files.get_abs_path(db_dir, "embedding.json")
150+
if files.exists(emb_set_file):
151+
embedding_set = json.loads(files.read_file(emb_set_file))
152+
if (
153+
embedding_set["model_provider"] == model_config.provider.name
154+
and embedding_set["model_name"] == model_config.name
155+
):
156+
# model matches
157+
emb_ok = True
158+
159+
# re-index - create new DB and insert existing docs
160+
if db and not emb_ok:
161+
docs = db.get_all_docs()
162+
db = None
163+
164+
# DB not loaded, create one
165+
if not db:
143166
index = faiss.IndexFlatIP(len(embedder.embed_query("example")))
144167

145168
db = MyFaiss(
@@ -151,7 +174,31 @@ def initialize(
151174
# normalize_L2=True,
152175
relevance_score_fn=Memory._cosine_normalizer,
153176
)
154-
return db # type: ignore
177+
178+
# insert docs if reindexing
179+
if docs:
180+
PrintStyle.standard("Indexing memories...")
181+
if log_item:
182+
log_item.stream(progress="\nIndexing memories")
183+
db.add_documents(documents=list(docs.values()), ids=list(docs.keys()))
184+
185+
# save DB
186+
Memory._save_db_file(db, memory_subdir)
187+
# save meta file
188+
meta_file_path = files.get_abs_path(db_dir, "embedding.json")
189+
files.write_file(
190+
meta_file_path,
191+
json.dumps(
192+
{
193+
"model_provider": model_config.provider.name,
194+
"model_name": model_config.name,
195+
}
196+
),
197+
)
198+
199+
created = True
200+
201+
return db, created
155202

156203
def __init__(
157204
self,
@@ -243,9 +290,10 @@ async def search_similarity_threshold(
243290
):
244291
comparator = Memory._get_comparator(filter) if filter else None
245292

246-
#rate limiter
293+
# rate limiter
247294
await self.agent.rate_limiter(
248-
model_config=self.agent.config.embeddings_model, input=query)
295+
model_config=self.agent.config.embeddings_model, input=query
296+
)
249297

250298
return await self.db.asearch(
251299
query,
@@ -309,25 +357,30 @@ async def insert_documents(self, docs: list[Document]):
309357
ids = [str(uuid.uuid4()) for _ in range(len(docs))]
310358
timestamp = self.get_timestamp()
311359

312-
313360
if ids:
314361
for doc, id in zip(docs, ids):
315362
doc.metadata["id"] = id # add ids to documents metadata
316363
doc.metadata["timestamp"] = timestamp # add timestamp
317364
if not doc.metadata.get("area", ""):
318365
doc.metadata["area"] = Memory.Area.MAIN.value
319-
320-
#rate limiter
366+
367+
# rate limiter
321368
docs_txt = "".join(self.format_docs_plain(docs))
322369
await self.agent.rate_limiter(
323-
model_config=self.agent.config.embeddings_model, input=docs_txt)
370+
model_config=self.agent.config.embeddings_model, input=docs_txt
371+
)
324372

325373
self.db.add_documents(documents=docs, ids=ids)
326374
self._save_db() # persist
327375
return ids
328376

329377
def _save_db(self):
330-
self.db.save_local(folder_path=self._abs_db_dir(self.memory_subdir))
378+
Memory._save_db_file(self.db, self.memory_subdir)
379+
380+
@staticmethod
381+
def _save_db_file(db: MyFaiss, memory_subdir: str):
382+
abs_dir = Memory._abs_db_dir(memory_subdir)
383+
db.save_local(folder_path=abs_dir)
331384

332385
@staticmethod
333386
def _get_comparator(condition: str):
@@ -382,3 +435,8 @@ def get_custom_knowledge_subdir_abs(agent: Agent) -> str:
382435
if dir != "default":
383436
return files.get_abs_path("knowledge", dir)
384437
raise Exception("No custom knowledge subdir set")
438+
439+
440+
def reload():
441+
# clear the memory index, this will force all DBs to reload
442+
Memory.index = {}

0 commit comments

Comments (0)