123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import os
- import sys
- from loguru import logger
- from pymilvus import (
- connections,
- Collection,
- CollectionSchema,
- DataType,
- FieldSchema,
- utility,
- )
- from app.core.config import settings
class MilvusHelper:
    """Convenience wrapper around the pymilvus ORM calls used by this module.

    The helper connects to Milvus on construction and keeps the most recently
    selected ``Collection`` in ``self.collection``.

    Error contract: every method logs the failure through ``logger.error`` and
    then terminates the process with ``sys.exit(1)``; no method re-raises to
    the caller.

    NOTE(review): the ``async`` methods call pymilvus without awaiting
    anything except ``create_collection``; presumably those calls block the
    event loop while they run -- confirm against the pymilvus client in use.
    """

    def __init__(self):
        """Open the Milvus connection.

        Host and port are read from the ``MILVUS_HOST`` / ``MILVUS_PORT``
        environment variables, falling back to the values in ``settings``.
        Exits the process with status 1 if the connection cannot be made.
        """
        self.collection = None
        try:
            # Resolve the endpoint once so the log line is guaranteed to show
            # exactly what was passed to ``connections.connect``.
            host = os.getenv("MILVUS_HOST", settings.MILVUS_HOST)
            port = os.getenv("MILVUS_PORT", settings.MILVUS_PORT)
            connections.connect(host=host, port=port)
            logger.debug(
                f"Successfully connect to Milvus with IP: {host} and PORT: {port}"
            )
        except Exception as e:
            logger.error(f"Failed to connect Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name: str) -> None:
        """Bind ``self.collection`` to the existing collection ``collection_name``.

        Exits the process with status 1 if the collection does not exist or
        cannot be opened.
        """
        try:
            if not self.has_collection(collection_name):
                raise Exception(f"There has no collection named: {collection_name}")
            self.collection = Collection(name=collection_name)
        except Exception as e:
            logger.error(f"Failed to set Milvus collection: {e}")
            sys.exit(1)

    @staticmethod
    def has_collection(collection_name: str):
        """Return the result of ``utility.has_collection`` for ``collection_name``.

        Exits the process with status 1 if the check itself fails.
        """
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            logger.error(f"Failed to check collection in Milvus: {e}")
            sys.exit(1)

    async def create_collection(self, collection_name: str) -> str:
        """Create ``collection_name`` if it does not exist yet.

        The schema has an auto-generated INT64 primary key ``pk`` and a
        FLOAT_VECTOR field ``embeddings`` of dimension
        ``settings.VECTOR_DIMENSION``. When the collection already exists
        nothing is created and ``self.collection`` is left untouched.

        Returns:
            The string ``"OK"`` in both cases.
        """
        try:
            if not self.has_collection(collection_name):
                pk_field = FieldSchema(
                    name="pk",
                    dtype=DataType.INT64,
                    description="int64",
                    is_primary=True,
                    auto_id=True,
                )
                embedding_field = FieldSchema(
                    name="embeddings",
                    dtype=DataType.FLOAT_VECTOR,
                    description="float vector",
                    dim=settings.VECTOR_DIMENSION,
                    is_primary=False,
                )
                schema = CollectionSchema(
                    fields=[pk_field, embedding_field],
                    description="Meeting attendee recommendation",
                )
                self.collection = Collection(name=collection_name, schema=schema)
                logger.debug(f"Create Milvus collection: {self.collection}")
            return "OK"
        except Exception as e:
            logger.error(f"Failed to create collection in Milvus: {e}")
            sys.exit(1)

    async def insert(self, collection_name: str, vectors: list):
        """Batch-insert ``vectors`` into ``collection_name``.

        The collection is created first if it is missing, and it is loaded
        after the insert.

        Returns:
            The primary keys reported by the insert result.
        """
        try:
            await self.create_collection(collection_name)
            self.collection = Collection(name=collection_name)
            # ``pk`` is auto-generated, so ``embeddings`` is the only column
            # supplied; the outer list is the list of columns.
            insert_result = self.collection.insert([vectors])
            ids = insert_result.primary_keys
            self.collection.load()
            logger.debug(
                f"Insert vectors to Milvus in collection: "
                f"{collection_name} with {len(vectors)} rows"
            )
            return ids
        except Exception as e:
            logger.error(f"Failed to load data to Milvus: {e}")
            sys.exit(1)

    async def create_index(self, collection_name: str):
        """Create a FLAT index on the ``embeddings`` field.

        The metric type is taken from ``settings.METRIC_TYPE``.

        Returns:
            The status object returned by ``Collection.create_index``. A
            truthy ``status.code`` is treated as a failure.
        """
        try:
            self.set_collection(collection_name)
            default_index = {
                "index_type": "FLAT",
                "metric_type": settings.METRIC_TYPE,
                "params": {},
            }
            status = self.collection.create_index(
                field_name="embeddings", index_params=default_index
            )
            if status.code:
                raise Exception(status.message)
            logger.debug(
                f"Successfully create index in collection: "
                f"{collection_name} with param: {default_index}"
            )
            return status
        except Exception as e:
            logger.error(f"Failed to create index: {e}")
            sys.exit(1)

    async def delete_collection(self, collection_name: str) -> str:
        """Drop ``collection_name`` and return the string ``"ok"``."""
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            logger.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            logger.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    async def search_vectors(
        self, collection_name: str, vectors: list, top_k: int
    ) -> list:
        """Search ``collection_name`` for the ``top_k`` nearest neighbours.

        Returns:
            A flat list of the hit ids for all query vectors, in result order.
        """
        try:
            self.set_collection(collection_name)
            search_params = {
                "metric_type": settings.METRIC_TYPE,
            }
            try:
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            except Exception:
                # Presumably the first attempt fails when the collection is
                # not loaded yet; load it and retry once. Only ``Exception``
                # is caught so KeyboardInterrupt/SystemExit are not swallowed.
                self.collection.load()
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            return [hit_id for hits in res for hit_id in hits.ids]
        except Exception as e:
            logger.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    async def count(self, collection_name: str):
        """Return ``num_entities`` of ``collection_name``."""
        try:
            self.set_collection(collection_name)
            return self.collection.num_entities
        except Exception as e:
            logger.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)

    async def query_vector_by_pk(self, collection_name: str, pk: int):
        """Return the ``embeddings`` value of the entity whose primary key is ``pk``.

        An empty query result makes ``res[0]`` fail, which is handled like any
        other error: it is logged and the process exits with status 1.
        """
        try:
            self.set_collection(collection_name)
            expr = f"pk in [{pk}]"
            try:
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            except Exception:
                # Same load-and-retry-once fallback as in ``search_vectors``.
                self.collection.load()
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            return res[0]["embeddings"]
        except Exception as e:
            logger.error(f"Failed to query vector in Milvus: {e}")
            sys.exit(1)
# Module-level registry; ``get_milvus_cli`` stores its helper under "cli".
my_milvus = {}


async def get_milvus_cli() -> None:
    """Instantiate a ``MilvusHelper`` and store it in ``my_milvus["cli"]``."""
    my_milvus["cli"] = MilvusHelper()
|