 from python.helpers import knowledge_import
 from python.helpers.log import Log, LogItem
 from enum import Enum
-from agent import Agent
+from agent import Agent, ModelConfig
 import models
 
 
@@ -36,6 +36,9 @@ def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
     async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
         return self.get_by_ids(ids)
 
+    def get_all_docs(self):
+        return self.docstore._dict  # type: ignore
+
 
 class Memory:
 
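Note on get_all_docs: it exposes the FAISS docstore's private `_dict` (an id -> Document mapping) so the whole store can be dumped and re-inserted into a fresh index during re-indexing. A minimal sketch of the same idea, assuming a LangChain FAISS store backed by InMemoryDocstore (the helper names here are illustrative, not part of this patch):

```python
# Sketch only: dumping and re-adding every document from a LangChain FAISS store.
from langchain_core.documents import Document

def dump_docs(db) -> dict[str, Document]:
    # InMemoryDocstore keeps its documents in a private dict keyed by id
    return dict(db.docstore._dict)  # type: ignore[attr-defined]

def reinsert(docs: dict[str, Document], new_db) -> None:
    # re-add the same documents under their original ids into a fresh store
    new_db.add_documents(documents=list(docs.values()), ids=list(docs.keys()))
```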
@@ -55,14 +58,9 @@ async def get(agent: Agent):
                 type="util",
                 heading=f"Initializing VectorDB in '/{memory_subdir}'",
             )
-            db = Memory.initialize(
+            db, created = Memory.initialize(
                 log_item,
-                models.get_model(
-                    models.ModelType.EMBEDDING,
-                    agent.config.embeddings_model.provider,
-                    agent.config.embeddings_model.name,
-                    **agent.config.embeddings_model.kwargs,
-                ),
+                agent.config.embeddings_model,
                 memory_subdir,
                 False,
             )
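Note on the ModelConfig argument: Memory.initialize now receives the agent's embeddings ModelConfig rather than a pre-built Embeddings instance, and constructs the model itself. The real ModelConfig is defined in agent.py and is not shown in this diff; judging by the attributes used below (provider, provider.name, name, kwargs), it is roughly shaped like this hypothetical sketch:

```python
# Hypothetical shape of ModelConfig, inferred from the attributes this diff uses
# (model_config.provider.name, model_config.name, **model_config.kwargs).
# The real definition lives in agent.py; the provider enum here is a placeholder.
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

class ModelProvider(Enum):  # placeholder; the real provider type is not shown in this diff
    OPENAI = "OpenAI"
    OLLAMA = "Ollama"

@dataclass
class ModelConfig:
    provider: ModelProvider
    name: str
    kwargs: dict[str, Any] = field(default_factory=dict)
```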
@@ -90,10 +88,10 @@ async def reload(agent: Agent):
     @staticmethod
     def initialize(
         log_item: LogItem | None,
-        embeddings_model: Embeddings,
+        model_config: ModelConfig,
         memory_subdir: str,
         in_memory=False,
-    ) -> MyFaiss:
+    ) -> tuple[MyFaiss, bool]:
 
         PrintStyle.standard("Initializing VectorDB...")
 
@@ -114,20 +112,26 @@ def initialize(
         os.makedirs(em_dir, exist_ok=True)
         store = LocalFileStore(em_dir)
 
+        embeddings_model = models.get_model(
+            models.ModelType.EMBEDDING,
+            model_config.provider,
+            model_config.name,
+            **model_config.kwargs,
+        )
+        embeddings_model_id = files.safe_file_name(
+            model_config.provider.name + "_" + model_config.name
+        )
+
         # here we setup the embeddings model with the chosen cache storage
         embedder = CacheBackedEmbeddings.from_bytes_store(
-            embeddings_model,
-            store,
-            namespace=getattr(
-                embeddings_model,
-                "model",
-                getattr(embeddings_model, "model_name", "default"),
-            ),
+            embeddings_model, store, namespace=embeddings_model_id
         )
 
-        # self.db = Chroma(
-        #     embedding_function=self.embedder,
-        #     persist_directory=db_dir)
+        # initial DB and docs variables
+        db: MyFaiss | None = None
+        docs: dict[str, Document] | None = None
+
+        created = False
 
         # if db folder exists and is not empty:
         if os.path.exists(db_dir) and files.exists(db_dir, "index.faiss"):
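Note on the cache namespace: it is now derived from the provider and model name instead of guessing a `model`/`model_name` attribute on the Embeddings object, so cached vectors from different embedding models cannot collide in the LocalFileStore. A sketch of the idea, with a simplified stand-in for files.safe_file_name:

```python
# Sketch: a stable, filesystem-safe cache namespace per (provider, model) pair.
# `safe_name` is a simplified stand-in for files.safe_file_name used in the patch.
import re
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore

def safe_name(text: str) -> str:
    # keep only characters that are safe in file names
    return re.sub(r"[^a-zA-Z0-9_.-]", "_", text)

def build_embedder(embeddings_model, cache_dir: str, provider_name: str, model_name: str):
    store = LocalFileStore(cache_dir)
    namespace = safe_name(provider_name + "_" + model_name)  # e.g. "OPENAI_text-embedding-3-small"
    return CacheBackedEmbeddings.from_bytes_store(embeddings_model, store, namespace=namespace)
```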
@@ -138,8 +142,27 @@ def initialize(
                 distance_strategy=DistanceStrategy.COSINE,
                 # normalize_L2=True,
                 relevance_score_fn=Memory._cosine_normalizer,
-            )
-        else:
+            )  # type: ignore
+
+            # if there is a mismatch in embeddings used, re-index the whole DB
+            emb_ok = False
+            emb_set_file = files.get_abs_path(db_dir, "embedding.json")
+            if files.exists(emb_set_file):
+                embedding_set = json.loads(files.read_file(emb_set_file))
+                if (
+                    embedding_set["model_provider"] == model_config.provider.name
+                    and embedding_set["model_name"] == model_config.name
+                ):
+                    # model matches
+                    emb_ok = True
+
+            # re-index - create new DB and insert existing docs
+            if db and not emb_ok:
+                docs = db.get_all_docs()
+                db = None
+
+        # DB not loaded, create one
+        if not db:
             index = faiss.IndexFlatIP(len(embedder.embed_query("example")))
 
             db = MyFaiss(
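Note on the mismatch check: the embedding model that built the index is persisted in an embedding.json file next to index.faiss and compared against the active config on load; any mismatch forces a re-index. A sketch of that check using plain json/os in place of the repo's files helpers:

```python
# Sketch of the embedding.json compatibility check, using plain json/os
# instead of the repo's files.* helpers.
import json
import os

def embeddings_match(db_dir: str, provider_name: str, model_name: str) -> bool:
    meta_path = os.path.join(db_dir, "embedding.json")
    if not os.path.exists(meta_path):
        return False  # no metadata recorded -> treat as mismatch and re-index
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    return (
        meta.get("model_provider") == provider_name
        and meta.get("model_name") == model_name
    )
```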
@@ -151,7 +174,31 @@ def initialize(
                 # normalize_L2=True,
                 relevance_score_fn=Memory._cosine_normalizer,
             )
-        return db  # type: ignore
+
+            # insert docs if reindexing
+            if docs:
+                PrintStyle.standard("Indexing memories...")
+                if log_item:
+                    log_item.stream(progress="\nIndexing memories")
+                db.add_documents(documents=list(docs.values()), ids=list(docs.keys()))
+
+            # save DB
+            Memory._save_db_file(db, memory_subdir)
+            # save meta file
+            meta_file_path = files.get_abs_path(db_dir, "embedding.json")
+            files.write_file(
+                meta_file_path,
+                json.dumps(
+                    {
+                        "model_provider": model_config.provider.name,
+                        "model_name": model_config.name,
+                    }
+                ),
+            )
+
+            created = True
+
+        return db, created
 
     def __init__(
         self,
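Note on the new return value: initialize now returns (db, created), so a caller can tell whether the store was freshly built or rebuilt. A hypothetical caller sketch; preload_knowledge is illustrative and does not appear in this diff:

```python
# Hypothetical use of the new (db, created) return value.
async def get_memory(agent, memory_subdir: str):
    db, created = Memory.initialize(
        None,                           # no log item in this sketch
        agent.config.embeddings_model,  # a ModelConfig, not an Embeddings instance
        memory_subdir,
        False,
    )
    if created:
        # only (re)index static knowledge when the DB was just built or rebuilt
        await preload_knowledge(agent, db)
    return db
```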
@@ -243,9 +290,10 @@ async def search_similarity_threshold(
     ):
         comparator = Memory._get_comparator(filter) if filter else None
 
-        #rate limiter
+        # rate limiter
         await self.agent.rate_limiter(
-            model_config=self.agent.config.embeddings_model, input=query)
+            model_config=self.agent.config.embeddings_model, input=query
+        )
 
         return await self.db.asearch(
             query,
@@ -309,25 +357,30 @@ async def insert_documents(self, docs: list[Document]):
         ids = [str(uuid.uuid4()) for _ in range(len(docs))]
         timestamp = self.get_timestamp()
 
-
         if ids:
             for doc, id in zip(docs, ids):
                 doc.metadata["id"] = id  # add ids to documents metadata
                 doc.metadata["timestamp"] = timestamp  # add timestamp
                 if not doc.metadata.get("area", ""):
                     doc.metadata["area"] = Memory.Area.MAIN.value
-
-            #rate limiter
+
+            # rate limiter
             docs_txt = "".join(self.format_docs_plain(docs))
             await self.agent.rate_limiter(
-                model_config=self.agent.config.embeddings_model, input=docs_txt)
+                model_config=self.agent.config.embeddings_model, input=docs_txt
+            )
 
             self.db.add_documents(documents=docs, ids=ids)
             self._save_db()  # persist
         return ids
 
     def _save_db(self):
-        self.db.save_local(folder_path=self._abs_db_dir(self.memory_subdir))
+        Memory._save_db_file(self.db, self.memory_subdir)
+
+    @staticmethod
+    def _save_db_file(db: MyFaiss, memory_subdir: str):
+        abs_dir = Memory._abs_db_dir(memory_subdir)
+        db.save_local(folder_path=abs_dir)
 
     @staticmethod
     def _get_comparator(condition: str):
@@ -382,3 +435,8 @@ def get_custom_knowledge_subdir_abs(agent: Agent) -> str:
         if dir != "default":
             return files.get_abs_path("knowledge", dir)
     raise Exception("No custom knowledge subdir set")
+
+
+def reload():
+    # clear the memory index, this will force all DBs to reload
+    Memory.index = {}
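Note on the module-level reload(): clearing Memory.index forces every cached MyFaiss instance to be re-initialized on the next Memory.get, which is how a changed embeddings model gets picked up (and re-indexed via the embedding.json check above). A hypothetical usage, assuming this file is python/helpers/memory.py based on the imports at the top of the diff:

```python
# Hypothetical settings-change hook. The module path "python.helpers.memory"
# is assumed, not confirmed by this diff.
from python.helpers import memory

def on_embeddings_settings_changed():
    # drop all cached MyFaiss instances; each will be lazily re-initialized
    # (and re-indexed if embedding.json no longer matches) on next access
    memory.reload()
```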