Skip to content

Commit 552bf23

Browse files
authored
Merge pull request #47 from storyandwine/main
add milvus
2 parents 6b06ead + 29d1ac5 commit 552bf23

File tree

2 files changed

+673
-2
lines changed

2 files changed

+673
-2
lines changed

torch_rechub/utils/match.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import OrderedDict, Counter
77
from annoy import AnnoyIndex
88
from .data import pad_sequences, df_to_dict
9-
9+
from pymilvus import Collection,CollectionSchema,DataType,FieldSchema,connections,utility
1010

1111
def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
1212
#merge user_profile and item_profile, pad history seuence feature
@@ -187,6 +187,70 @@ def query(self, v, n):
187187
def __str__(self):
188188
return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
189189

190+
191+
class Milvus(object):
192+
"""Vector matching by Milvus.
193+
194+
Args:
195+
dim (int): embedding dim
196+
host (str): host address of Milvus
197+
port (str): port of Milvus
198+
"""
199+
200+
def __init__(self, dim=64, host="localhost", port="19530"):
201+
print("Start connecting to Milvus")
202+
connections.connect("default", host=host, port=port)
203+
self.dim = dim
204+
has = utility.has_collection("rechub")
205+
#print(f"Does collection rechub exist? {has}")
206+
if has:
207+
utility.drop_collection("rechub")
208+
# Create collection
209+
fields = [
210+
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
211+
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
212+
]
213+
schema = CollectionSchema(fields=fields)
214+
self.milvus = Collection("rechub", schema=schema)
215+
216+
def fit(self, X):
217+
if torch.is_tensor(X):
218+
X = X.cpu().numpy()
219+
self.milvus.release()
220+
entities = [[i for i in range(len(X))], X]
221+
self.milvus.insert(entities)
222+
print(
223+
f"Number of entities in Milvus: {self.milvus.num_entities}"
224+
) # check the num_entites
225+
226+
index = {
227+
"index_type": "IVF_FLAT",
228+
"metric_type": "L2",
229+
"params": {"nlist": 128},
230+
}
231+
self.milvus.create_index("embeddings", index)
232+
233+
@staticmethod
234+
def process_result(results):
235+
idx_list = []
236+
score_list = []
237+
for r in results:
238+
temp_idx_list = []
239+
temp_score_list = []
240+
for i in range(len(r)):
241+
temp_idx_list.append(r[i].id)
242+
temp_score_list.append(r[i].distance)
243+
idx_list.append(temp_idx_list)
244+
score_list.append(temp_score_list)
245+
return idx_list, score_list
246+
247+
def query(self, v, n):
248+
if torch.is_tensor(v):
249+
v = v.cpu().numpy().reshape(-1, self.dim)
250+
self.milvus.load()
251+
search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
252+
results = self.milvus.search(v, "embeddings", search_params, n)
253+
return self.process_result(results)
190254

191255
#annoy = Annoy(n_trees=10)
192-
#annoy.fit(item_embs)
256+
#annoy.fit(item_embs)

0 commit comments

Comments
 (0)