milvus_helpers.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import os
  2. import sys
  3. from loguru import logger
  4. from pymilvus import (
  5. connections,
  6. Collection,
  7. CollectionSchema,
  8. DataType,
  9. FieldSchema,
  10. utility,
  11. )
  12. from app.core.config import settings
  13. class MilvusHelper:
  14. def __init__(self):
  15. try:
  16. self.collection = None
  17. connections.connect(
  18. host=os.getenv("MILVUS_HOST", settings.MILVUS_HOST),
  19. port=os.getenv("MILVUS_PORT", settings.MILVUS_PORT),
  20. )
  21. logger.debug(
  22. f"Successfully connect to Milvus with IP: "
  23. f"{os.getenv('MILVUS_HOST', settings.MILVUS_HOST)} "
  24. f"and PORT: {os.getenv('MILVUS_PORT', settings.MILVUS_PORT)}"
  25. )
  26. except Exception as e:
  27. logger.error(f"Failed to connect Milvus: {e}")
  28. sys.exit(1)
  29. def set_collection(self, collection_name: str):
  30. try:
  31. if self.has_collection(collection_name):
  32. self.collection = Collection(name=collection_name)
  33. else:
  34. raise Exception(f"There has no collection named: {collection_name}")
  35. except Exception as e:
  36. logger.error(f"Failed to search data to Milvus: {e}")
  37. sys.exit(1)
  38. # Return if Milvus has the collection
  39. @staticmethod
  40. def has_collection(collection_name: str):
  41. try:
  42. status = utility.has_collection(collection_name)
  43. return status
  44. except Exception as e:
  45. logger.error(f"Failed to load data to Milvus: {e}")
  46. sys.exit(1)
  47. # Create milvus collection if not exists
  48. async def create_collection(self, collection_name: str):
  49. try:
  50. if not self.has_collection(collection_name):
  51. field1 = FieldSchema(
  52. name="pk",
  53. dtype=DataType.INT64,
  54. descrition="int64",
  55. is_primary=True,
  56. auto_id=True,
  57. )
  58. field2 = FieldSchema(
  59. name="embeddings",
  60. dtype=DataType.FLOAT_VECTOR,
  61. descrition="float vector",
  62. dim=settings.VECTOR_DIMENSION,
  63. is_primary=False,
  64. )
  65. schema = CollectionSchema(
  66. fields=[field1, field2],
  67. description="Meeting attendee recommendation",
  68. )
  69. self.collection = Collection(name=collection_name, schema=schema)
  70. logger.debug(f"Create Milvus collection: {self.collection}")
  71. return "OK"
  72. except Exception as e:
  73. logger.error(f"Failed to load data to Milvus: {e}")
  74. sys.exit(1)
  75. # Batch insert vectors to milvus collection
  76. async def insert(self, collection_name: str, vectors: list):
  77. try:
  78. await self.create_collection(collection_name)
  79. self.collection = Collection(name=collection_name)
  80. data = [vectors]
  81. mr = self.collection.insert(data)
  82. ids = mr.primary_keys
  83. self.collection.load()
  84. logger.debug(
  85. f"Insert vectors to Milvus in collection: "
  86. f"{collection_name} with {vectors} rows"
  87. )
  88. return ids
  89. except Exception as e:
  90. logger.error(f"Failed to load data to Milvus: {e}")
  91. sys.exit(1)
  92. # Create FLAT index on milvus collection
  93. async def create_index(self, collection_name: str):
  94. try:
  95. self.set_collection(collection_name)
  96. default_index = {
  97. "index_type": "FLAT",
  98. "metric_type": settings.METRIC_TYPE,
  99. "params": {},
  100. }
  101. status = self.collection.create_index(
  102. field_name="embeddings", index_params=default_index
  103. )
  104. if not status.code:
  105. logger.debug(
  106. f"Successfully create index in collection: "
  107. f"{collection_name} with param: {default_index}"
  108. )
  109. return status
  110. else:
  111. raise Exception(status.message)
  112. except Exception as e:
  113. logger.error(f"Failed to create index: {e}")
  114. sys.exit(1)
  115. # Delete Milvus collection
  116. async def delete_collection(self, collection_name: str):
  117. try:
  118. self.set_collection(collection_name)
  119. self.collection.drop()
  120. logger.debug("Successfully drop collection!")
  121. return "ok"
  122. except Exception as e:
  123. logger.error(f"Failed to drop collection: {e}")
  124. sys.exit(1)
  125. # Search vector in milvus collection
  126. async def search_vectors(self, collection_name: str, vectors: list, top_k: int):
  127. # status = utility.list_collections()
  128. try:
  129. self.set_collection(collection_name)
  130. search_params = {
  131. "metric_type": settings.METRIC_TYPE,
  132. }
  133. try:
  134. res = self.collection.search(
  135. vectors, anns_field="embeddings", param=search_params, limit=top_k
  136. )
  137. except BaseException:
  138. self.collection.load()
  139. res = self.collection.search(
  140. vectors, anns_field="embeddings", param=search_params, limit=top_k
  141. )
  142. pk_list = []
  143. for hits in res:
  144. for hit in hits.ids:
  145. pk_list.append(hit)
  146. return pk_list
  147. except Exception as e:
  148. logger.error(f"Failed to search vectors in Milvus: {e}")
  149. sys.exit(1)
  150. # Get the number of milvus collection
  151. async def count(self, collection_name: str):
  152. try:
  153. self.set_collection(collection_name)
  154. num = self.collection.num_entities
  155. return num
  156. except Exception as e:
  157. logger.error(f"Failed to count vectors in Milvus: {e}")
  158. sys.exit(1)
  159. # Query vector by primary key
  160. async def query_vector_by_pk(self, collection_name: str, pk: int):
  161. try:
  162. self.set_collection(collection_name)
  163. expr = f"pk in [{pk}]"
  164. try:
  165. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  166. except BaseException:
  167. self.collection.load()
  168. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  169. vector = res[0]["embeddings"]
  170. return vector
  171. except Exception as e:
  172. logger.error(f"Failed to query vector in Milvus: {e}")
  173. sys.exit(1)
  174. my_milvus = {}
  175. async def get_milvus_cli() -> None:
  176. MILVUS_CLI = MilvusHelper()
  177. my_milvus.update({"cli": MILVUS_CLI})