import os
import time

from loguru import logger
from pymilvus import connections, Collection, SearchResult

from app.core.config import settings


class Milvus:
    """Small wrapper around the pymilvus ORM client.

    Connection parameters are taken from the ``MILVUS_HOST`` / ``MILVUS_PORT``
    environment variables, falling back to ``settings.MILVUS_HOST`` /
    ``settings.MILVUS_PORT``.

    NOTE(review): the pymilvus calls below are not awaited, so they run
    synchronously inside these ``async`` methods and presumably block the
    event loop while they execute — consider ``asyncio.to_thread`` if that
    matters for callers.
    """

    def __init__(self) -> None:
        self._host = os.getenv("MILVUS_HOST", settings.MILVUS_HOST)
        self._port = os.getenv("MILVUS_PORT", settings.MILVUS_PORT)

    async def query_embedding_by_pk(
        self,
        collection_name: str,
        primary_key_name: str,
        pk: int,
        output_field: str = "embeddings",
    ) -> list:
        """Fetch the stored vector for a single primary key.

        Args:
            collection_name: Name of the Milvus collection to query.
            primary_key_name: Name of the collection's primary-key field.
            pk: Primary-key value to look up.
            output_field: Field whose value is returned (default
                ``"embeddings"``).

        Returns:
            The value of ``output_field`` for the first matching row, or
            ``[]`` when no row matches or the row lacks that field (the miss
            is logged at error level).
        """
        start = time.perf_counter()
        # The "default" alias is intentionally left connected after the call.
        connections.connect("default", host=self._host, port=self._port)
        logger.debug(time.perf_counter() - start)

        collection = Collection(name=collection_name)
        start = time.perf_counter()
        collection.load()
        logger.debug(time.perf_counter() - start)

        try:
            expr = f"{primary_key_name} in [{pk}]"
            start = time.perf_counter()
            res = collection.query(expr=expr, output_fields=[output_field])
            logger.debug(time.perf_counter() - start)

            try:
                # Index (not .get) so a missing field falls into the
                # fallback instead of silently returning None.
                emb = res[0][output_field]
            except (KeyError, IndexError) as e:  # must be a tuple, not a list
                emb = []
                logger.error(f"Can't find embedding by {pk}, reason: {e}")
        finally:
            # Release the loaded collection even if the query raised.
            start = time.perf_counter()
            collection.release()
            logger.debug(time.perf_counter() - start)

        return emb

    async def search(
        self,
        vec_list: list,
        collection_name: str,
        field_name: str,
        limit: int,
    ) -> list:
        """Run an L2 vector search and return the matching primary keys.

        Args:
            vec_list: Query vectors passed straight to ``Collection.search``.
            collection_name: Name of the Milvus collection to search.
            field_name: Vector field to search against.
            limit: Maximum number of hits per query vector.

        Returns:
            Primary keys of all hits, flattened across query vectors in
            result order. Empty if the search result is not a
            ``SearchResult``.
        """
        # The "default" alias is intentionally left connected after the call.
        connections.connect("default", host=self._host, port=self._port)
        collection = Collection(name=collection_name)
        collection.load()

        search_params = {
            "metric_type": "L2",
        }
        pk_list: list = []
        try:
            start = time.perf_counter()
            res = collection.search(
                vec_list,
                field_name,
                param=search_params,
                limit=limit,
            )
            logger.debug(time.perf_counter() - start)

            if isinstance(res, SearchResult):
                for hits in res:
                    pk_list.extend(hits.ids)
        finally:
            # Release the loaded collection even if the search raised.
            collection.release()

        return pk_list