import os
import sys

from loguru import logger
from pymilvus import (
    connections,
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    utility,
)

from app.core.config import settings


class MilvusHelper:
    """Wrapper around a pymilvus connection and one "current" collection.

    Failure policy: every method logs the error and terminates the process
    with ``sys.exit(1)`` when the underlying call raises ``Exception``.
    """

    def __init__(self):
        """Connect to Milvus; exit the process if the connection fails.

        Host and port come from the ``MILVUS_HOST`` / ``MILVUS_PORT``
        environment variables, falling back to the values in ``settings``.
        """
        self.collection = None
        try:
            host = os.getenv("MILVUS_HOST", settings.MILVUS_HOST)
            port = os.getenv("MILVUS_PORT", settings.MILVUS_PORT)
            connections.connect(host=host, port=port)
            logger.debug(
                f"Successfully connect to Milvus with IP: {host} and PORT: {port}"
            )
        except Exception as e:
            logger.error(f"Failed to connect Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name: str) -> None:
        """Bind ``self.collection`` to an existing collection.

        Exits the process if the collection does not exist or cannot be
        opened.
        """
        try:
            if not self.has_collection(collection_name):
                raise Exception(f"There has no collection named: {collection_name}")
            self.collection = Collection(name=collection_name)
        except Exception as e:
            logger.error(f"Failed to set Milvus collection: {e}")
            sys.exit(1)

    @staticmethod
    def has_collection(collection_name: str):
        """Return the result of ``utility.has_collection`` for the given name."""
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            logger.error(f"Failed to check Milvus collection: {e}")
            sys.exit(1)

    async def create_collection(self, collection_name: str) -> str:
        """Create the collection if it does not exist yet and return ``"OK"``.

        The schema has an auto-generated INT64 primary key ``pk`` and a
        FLOAT_VECTOR field ``embeddings`` of dimension
        ``settings.VECTOR_DIMENSION``. When the collection already exists
        nothing is created and ``self.collection`` is left untouched.
        """
        try:
            if not self.has_collection(collection_name):
                pk_field = FieldSchema(
                    name="pk",
                    dtype=DataType.INT64,
                    description="int64",
                    is_primary=True,
                    auto_id=True,
                )
                vector_field = FieldSchema(
                    name="embeddings",
                    dtype=DataType.FLOAT_VECTOR,
                    description="float vector",
                    dim=settings.VECTOR_DIMENSION,
                    is_primary=False,
                )
                schema = CollectionSchema(
                    fields=[pk_field, vector_field],
                    description="Meeting attendee recommendation",
                )
                self.collection = Collection(name=collection_name, schema=schema)
                logger.debug(f"Create Milvus collection: {self.collection}")
            return "OK"
        except Exception as e:
            logger.error(f"Failed to create Milvus collection: {e}")
            sys.exit(1)

    async def insert(self, collection_name: str, vectors: list):
        """Insert ``vectors`` into the collection and return the new primary keys.

        The collection is created first when missing. ``vectors`` is passed
        as the single data column (the primary key is auto-generated), and
        the collection is loaded after the insert.
        """
        try:
            await self.create_collection(collection_name)
            self.collection = Collection(name=collection_name)
            mutation_result = self.collection.insert([vectors])
            ids = mutation_result.primary_keys
            self.collection.load()
            # Log the row count only; dumping the raw vectors would flood the log.
            logger.debug(
                f"Insert vectors to Milvus in collection: "
                f"{collection_name} with {len(vectors)} rows"
            )
            return ids
        except Exception as e:
            logger.error(f"Failed to load data to Milvus: {e}")
            sys.exit(1)

    async def create_index(self, collection_name: str):
        """Create a FLAT index on the ``embeddings`` field.

        Uses ``settings.METRIC_TYPE`` as the metric. Returns the status
        object from ``Collection.create_index`` when its ``code`` is falsy;
        a truthy code is treated as a failure and exits the process.
        """
        try:
            self.set_collection(collection_name)
            default_index = {
                "index_type": "FLAT",
                "metric_type": settings.METRIC_TYPE,
                "params": {},
            }
            status = self.collection.create_index(
                field_name="embeddings", index_params=default_index
            )
            if status.code:
                raise Exception(status.message)
            logger.debug(
                f"Successfully create index in collection: "
                f"{collection_name} with param: {default_index}"
            )
            return status
        except Exception as e:
            logger.error(f"Failed to create index: {e}")
            sys.exit(1)

    async def delete_collection(self, collection_name: str) -> str:
        """Drop the named collection and return ``"ok"``."""
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            logger.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            logger.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    async def search_vectors(self, collection_name: str, vectors: list, top_k: int):
        """Search ``embeddings`` for ``vectors`` and return the hit ids.

        The ids of all hits for all query vectors are returned as one flat
        list. If the first search attempt fails, the collection is loaded
        and the search is retried once.
        """
        try:
            self.set_collection(collection_name)
            search_params = {
                "metric_type": settings.METRIC_TYPE,
            }
            try:
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            except Exception:
                # Retry once after loading the collection. Only Exception is
                # caught so KeyboardInterrupt/SystemExit are not retried.
                self.collection.load()
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            return [pk for hits in res for pk in hits.ids]
        except Exception as e:
            logger.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    async def count(self, collection_name: str):
        """Return ``num_entities`` of the named collection."""
        try:
            self.set_collection(collection_name)
            return self.collection.num_entities
        except Exception as e:
            logger.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)

    async def query_vector_by_pk(self, collection_name: str, pk: int):
        """Return the ``embeddings`` value of the first row matching ``pk``.

        If the first query attempt fails, the collection is loaded and the
        query is retried once. An empty result makes ``res[0]`` raise, which
        is handled like any other failure (log and exit).
        """
        try:
            self.set_collection(collection_name)
            expr = f"pk in [{pk}]"
            try:
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            except Exception:
                # Retry once after loading the collection. Only Exception is
                # caught so KeyboardInterrupt/SystemExit are not retried.
                self.collection.load()
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            return res[0]["embeddings"]
        except Exception as e:
            logger.error(f"Failed to query vector in Milvus: {e}")
            sys.exit(1)


# Module-level registry; get_milvus_cli() stores the helper under the "cli" key.
my_milvus = {}


async def get_milvus_cli() -> None:
    """Create a MilvusHelper and store it in ``my_milvus["cli"]``."""
    cli = MilvusHelper()
    my_milvus.update({"cli": cli})