123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- import os
- import time
- from loguru import logger
- from pymilvus import connections, Collection, SearchResult
- from app.core.config import settings
class Milvus:
    """Small client wrapper around a Milvus vector database (via pymilvus).

    The host and port are read from the ``MILVUS_HOST`` / ``MILVUS_PORT``
    environment variables, falling back to the application ``settings``.

    NOTE(review): the public methods are declared ``async`` but only make
    plain (non-awaited) pymilvus calls, so they presumably block the event
    loop while running -- confirm this is acceptable for the callers.
    """

    def __init__(self) -> None:
        # Environment variables take precedence over the values in settings.
        self._host = os.getenv("MILVUS_HOST", settings.MILVUS_HOST)
        self._port = os.getenv("MILVUS_PORT", settings.MILVUS_PORT)

    async def query_embedding_by_pk(
        self,
        collection_name: str,
        primary_key_name: str,
        pk: int,
        output_field: str = "embeddings",
    ) -> list:
        """Fetch one field of the entity whose primary key equals ``pk``.

        Args:
            collection_name: Name of the collection to query.
            primary_key_name: Name of the primary-key field used in the
                filter expression.
            pk: Primary-key value to look up.
            output_field: Field to read from the first matching entity.

        Returns:
            The value of ``output_field`` for the first matching entity, or
            an empty list when the query returns no rows.
        """
        start = time.perf_counter()
        connections.connect("default", host=self._host, port=self._port)
        logger.debug(time.perf_counter() - start)

        collection = Collection(name=collection_name)
        start = time.perf_counter()
        collection.load()
        logger.debug(time.perf_counter() - start)

        try:
            expr = f"{primary_key_name} in [{pk}]"
            start = time.perf_counter()
            res = collection.query(expr=expr, output_fields=[output_field])
            logger.debug(time.perf_counter() - start)

            try:
                emb = res[0].get(output_field)
            # The exception classes must be given as a tuple: a list here
            # raises TypeError as soon as a lookup actually fails, so the
            # fallback below would never run.
            except (KeyError, IndexError) as e:
                emb = []
                logger.error(f"Can't find embedding by {pk}, reason: {e}")
        finally:
            # Release the loaded collection even when the query fails.
            # The "default" connection is deliberately left open.
            start = time.perf_counter()
            collection.release()
            logger.debug(time.perf_counter() - start)

        return emb

    async def search(
        self,
        vec_list: list,
        collection_name: str,
        field_name: str,
        limit: int,
    ) -> list:
        """Run an L2 vector search and return the ids of all hits.

        Args:
            vec_list: Query vectors passed to ``Collection.search``.
            collection_name: Name of the collection to search.
            field_name: Name of the vector field to search against.
            limit: Value passed as ``limit`` to ``Collection.search``.

        Returns:
            A flat list with the ids of every hit, in result order. Empty
            when the search result is not a ``SearchResult``.
        """
        connections.connect("default", host=self._host, port=self._port)
        collection = Collection(name=collection_name)
        collection.load()

        search_param = {
            "metric_type": "L2",
        }

        try:
            start = time.perf_counter()
            res = collection.search(
                vec_list,
                field_name,
                param=search_param,
                limit=limit,
            )
            logger.debug(time.perf_counter() - start)

            pk_list = []
            if isinstance(res, SearchResult):
                # Flatten the per-query hit groups into one list of ids.
                pk_list = [hit_id for hits in res for hit_id in hits.ids]
        finally:
            # Release the loaded collection even when the search fails.
            # The "default" connection is deliberately left open.
            collection.release()

        return pk_list
|