6
6
from collections import OrderedDict , Counter
7
7
from annoy import AnnoyIndex
8
8
from .data import pad_sequences , df_to_dict
9
-
9
+ from pymilvus import Collection , CollectionSchema , DataType , FieldSchema , connections , utility
10
10
11
11
def gen_model_input (df , user_profile , user_col , item_profile , item_col , seq_max_len , padding = 'pre' , truncating = 'pre' ):
12
12
#merge user_profile and item_profile, pad history sequence feature
@@ -187,6 +187,70 @@ def query(self, v, n):
187
187
def __str__(self):
    """Return a one-line description showing ``_n_trees`` and ``_search_k``."""
    # Both values are rendered with %d, exactly as the format string requires.
    settings = (self._n_trees, self._search_k)
    return 'Annoy(n_trees=%d, search_k=%d)' % settings
189
189
190
+
191
class Milvus(object):
    """Vector matching by Milvus.

    The instance owns one Milvus collection with two fields: an INT64 primary
    key ``id`` and a FLOAT_VECTOR field ``embeddings`` of size ``dim``.
    ``fit`` inserts the vectors and builds an IVF_FLAT index; ``query`` loads
    the collection and runs a nearest-neighbour search.

    Note:
        Constructing an instance connects to the Milvus server and DROPS any
        existing collection named ``collection_name`` before re-creating it.

    Args:
        dim (int): embedding dim
        host (str): host address of Milvus
        port (str): port of Milvus
        collection_name (str): name of the collection to (re)create.
        metric_type (str): metric used for BOTH index creation and search,
            kept in one place so the two can never disagree.
        nlist (int): ``nlist`` parameter passed to the IVF_FLAT index.
        nprobe (int): ``nprobe`` parameter passed at search time.
    """

    def __init__(
        self,
        dim=64,
        host="localhost",
        port="19530",
        collection_name="rechub",
        metric_type="L2",
        nlist=128,
        nprobe=16,
    ):
        print("Start connecting to Milvus")
        connections.connect("default", host=host, port=port)
        self.dim = dim
        self.collection_name = collection_name
        self.metric_type = metric_type
        self.nlist = nlist
        self.nprobe = nprobe

        # Always start from an empty collection: stale vectors from an earlier
        # run would otherwise be mixed with the ones inserted by fit().
        if utility.has_collection(collection_name):
            utility.drop_collection(collection_name)

        # Create collection
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields)
        self.milvus = Collection(collection_name, schema=schema)

    def fit(self, X):
        """Insert the embeddings ``X`` and build the IVF_FLAT index.

        Args:
            X: embeddings to index, either a torch tensor or an object that
                supports ``len`` and is accepted by ``Collection.insert``.
                Row ``i`` is stored under primary key ``i``.
        """
        if torch.is_tensor(X):
            # detach() so tensors that still track gradients can be converted.
            X = X.detach().cpu().numpy()
        self.milvus.release()
        entities = [list(range(len(X))), X]
        self.milvus.insert(entities)
        print(
            f"Number of entities in Milvus: {self.milvus.num_entities} "
        )  # check the num_entities

        index = {
            "index_type": "IVF_FLAT",
            "metric_type": self.metric_type,
            "params": {"nlist": self.nlist},
        }
        self.milvus.create_index("embeddings", index)

    @staticmethod
    def process_result(results):
        """Split search results into parallel id and distance lists.

        Args:
            results: iterable with one entry per query vector; each entry is
                an iterable of hits exposing ``id`` and ``distance``.

        Returns:
            tuple(list, list): ``(idx_list, score_list)`` where
            ``idx_list[q][k]`` / ``score_list[q][k]`` are the id and distance
            of the k-th hit of the q-th query.
        """
        idx_list = []
        score_list = []
        for hits in results:
            # Walk each hit container exactly once and fan out afterwards.
            pairs = [(hit.id, hit.distance) for hit in hits]
            idx_list.append([hit_id for hit_id, _ in pairs])
            score_list.append([distance for _, distance in pairs])
        return idx_list, score_list

    def query(self, v, n):
        """Search the ``n`` nearest stored embeddings for each query vector.

        Args:
            v: query vector(s). Torch tensors and numpy arrays are reshaped to
                ``(-1, dim)`` so a single 1-D vector is accepted; any other
                input (e.g. a list of lists) is passed to Milvus unchanged.
            n (int): number of neighbours to return per query vector.

        Returns:
            tuple(list, list): ids and distances, see ``process_result``.
        """
        import numpy as np  # local import: only needed for the ndarray check

        if torch.is_tensor(v):
            v = v.detach().cpu().numpy()
        if isinstance(v, np.ndarray):
            # Previously only tensors were reshaped, so a 1-D numpy query
            # reached Milvus with the wrong shape.
            v = v.reshape(-1, self.dim)
        self.milvus.load()
        search_params = {
            "metric_type": self.metric_type,
            "params": {"nprobe": self.nprobe},
        }
        results = self.milvus.search(v, "embeddings", search_params, n)
        return self.process_result(results)

    def __str__(self):
        return 'Milvus(dim=%d, collection=%s, metric_type=%s)' % (
            self.dim,
            self.collection_name,
            self.metric_type,
        )
190
254
191
255
#annoy = Annoy(n_trees=10)
192
- #annoy.fit(item_embs)
256
+ #annoy.fit(item_embs)
0 commit comments