milvus.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import os
  2. import time
  3. from loguru import logger
  4. from pymilvus import connections, Collection, SearchResult
  5. from app.core.config import settings
  6. class Milvus:
  7. """
  8. Open-source vector database for unstructured data.
  9. """
  10. def __init__(self) -> None:
  11. self._host = os.getenv("MILVUS_HOST", settings.MILVUS_HOST)
  12. self._port = os.getenv("MILVUS_PORT", settings.MILVUS_PORT)
  13. async def query_embedding_by_pk(
  14. self,
  15. collection_name: str,
  16. primary_key_name: str,
  17. pk: int,
  18. output_field: str = "embeddings",
  19. ) -> list:
  20. start = time.perf_counter()
  21. connections.connect("default", host=self._host, port=self._port)
  22. logger.debug(time.perf_counter() - start)
  23. collection = Collection(name=collection_name)
  24. start = time.perf_counter()
  25. collection.load()
  26. logger.debug(time.perf_counter() - start)
  27. expr = f"{primary_key_name} in [{pk}]"
  28. start = time.perf_counter()
  29. res = collection.query(expr=expr, output_fields=[output_field])
  30. logger.debug(time.perf_counter() - start)
  31. try:
  32. emb = res[0].get(output_field)
  33. except [KeyError, IndexError] as e:
  34. emb = []
  35. logger.error(f"Can't find embedding by {pk}, reason: {e}")
  36. start = time.perf_counter()
  37. collection.release()
  38. # connections.disconnect("default")
  39. logger.debug(time.perf_counter() - start)
  40. return emb
  41. async def search(
  42. self,
  43. vec_list: list,
  44. collection_name: str,
  45. field_name: str,
  46. limit: int,
  47. ) -> list:
  48. connections.connect("default", host=self._host, port=self._port)
  49. collection = Collection(name=collection_name)
  50. collection.load()
  51. SEARCH_PARAM = {
  52. "metric_type": "L2",
  53. }
  54. start = time.perf_counter()
  55. res = collection.search(
  56. vec_list,
  57. field_name,
  58. param=SEARCH_PARAM,
  59. limit=limit,
  60. )
  61. logger.debug(time.perf_counter() - start)
  62. pk_list = []
  63. if isinstance(res, SearchResult):
  64. for hits in res:
  65. for hit in hits.ids:
  66. pk_list.append(hit)
  67. collection.release()
  68. # connections.disconnect("default")
  69. return pk_list