diff --git a/config.ini.example b/config.ini.example
new file mode 100644
index 00000000..e7916b01
--- /dev/null
+++ b/config.ini.example
@@ -0,0 +1,15 @@
+[neo4j]
+uri = neo4j+s://xxxxxxxx.databases.neo4j.io
+username = neo4j
+password = your-password
+
+[mongodb]
+uri = mongodb+srv://name:password@your-cluster-address
+database = lightrag
+graph = false
+
+[redis]
+uri=redis://localhost:6379/1
+
+[qdrant]
+uri = http://localhost:16333
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 23507dd1..8d13fab0 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -113,14 +113,26 @@ if milvus_uri:
     os.environ["MILVUS_DB_NAME"] = milvus_db_name
     rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
 
+# Qdrant config
+qdrant_uri = config.get("qdrant", "uri", fallback=None)
+qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
+if qdrant_uri:
+    os.environ["QDRANT_URL"] = qdrant_uri
+    if qdrant_api_key:
+        os.environ["QDRANT_API_KEY"] = qdrant_api_key
+    rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
+
 # MongoDB config
 mongo_uri = config.get("mongodb", "uri", fallback=None)
-mongo_database = config.get("mongodb", "LightRAG", fallback=None)
+mongo_database = config.get("mongodb", "database", fallback="LightRAG")
+mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
 if mongo_uri:
     os.environ["MONGO_URI"] = mongo_uri
     os.environ["MONGO_DATABASE"] = mongo_database
     rag_storage_config.KV_STORAGE = "MongoKVStorage"
     rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
+    if mongo_graph:
+        rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
 
 
 def get_default_host(binding_type: str) -> str:
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
new file mode 100644
index 00000000..e2f8d3a2
--- /dev/null
+++ b/lightrag/kg/qdrant_impl.py
@@ -0,0 +1,163 @@
+import asyncio
+import os
+from tqdm.asyncio import tqdm as tqdm_async
+from dataclasses import dataclass
+import numpy as np
+import hashlib
+import uuid
+
+from ..utils import logger
+from ..base import BaseVectorStorage
+
+import pipmaster as pm
+
+if not pm.is_installed("qdrant_client"):
+    pm.install("qdrant_client")
+
+from qdrant_client import QdrantClient, models
+
+
+def compute_mdhash_id_for_qdrant(
+    content: str, prefix: str = "", style: str = "simple"
+) -> str:
+    """
+    Generate a deterministic UUID from the content, in one of several formats.
+
+    The same ``prefix`` and ``content`` always yield the same UUID, so the
+    result can serve as a stable Qdrant point ID for a given key.
+
+    :param content: The content used to generate the UUID.
+    :param prefix: Text prepended to the content before hashing.
+    :param style: The format of the UUID, optional values are "simple", "hyphenated", "urn".
+    :return: A UUID that meets the requirements of Qdrant.
+    :raises ValueError: If the content is empty or the style is unknown.
+    """
+    if not content:
+        raise ValueError("Content must not be empty.")
+
+    # Use the first 16 bytes of the SHA-256 digest as the UUID bytes.
+    hashed_content = hashlib.sha256((prefix + content).encode("utf-8")).digest()
+    generated_uuid = uuid.UUID(bytes=hashed_content[:16], version=4)
+
+    # Return the UUID according to the specified format.
+    if style == "simple":
+        return generated_uuid.hex
+    if style == "hyphenated":
+        return str(generated_uuid)
+    if style == "urn":
+        return f"urn:uuid:{generated_uuid}"
+    raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
+
+
+@dataclass
+class QdrantVectorDBStorage(BaseVectorStorage):
+    """Vector storage backed by a Qdrant collection named after ``namespace``."""
+
+    @staticmethod
+    def create_collection_if_not_exist(
+        client: QdrantClient, collection_name: str, **kwargs
+    ):
+        """Create ``collection_name`` with ``kwargs`` unless it already exists."""
+        if client.collection_exists(collection_name):
+            return
+        client.create_collection(collection_name, **kwargs)
+
+    def __post_init__(self):
+        # Connection settings come from the environment (QDRANT_URL and,
+        # optionally, QDRANT_API_KEY).
+        self._client = QdrantClient(
+            url=os.environ.get("QDRANT_URL"),
+            api_key=os.environ.get("QDRANT_API_KEY", None),
+        )
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        QdrantVectorDBStorage.create_collection_if_not_exist(
+            self._client,
+            self.namespace,
+            vectors_config=models.VectorParams(
+                size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
+            ),
+        )
+
+    async def upsert(self, data: dict[str, dict]):
+        """
+        Embed the ``content`` of every entry and upsert the vectors into Qdrant.
+
+        :param data: Mapping of record key to its fields; each value must have
+            a "content" entry, which is the text that gets embedded.
+        :return: The result of the Qdrant upsert call, or an empty list when
+            ``data`` is empty.
+        """
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You insert an empty data to vector DB")
+            return []
+        # The payload keeps the original key under "id" because the point ID
+        # itself is a UUID derived from that key.
+        list_data = [
+            {
+                "id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        # Close the progress bar even when an embedding call fails.
+        with tqdm_async(
+            total=len(batches), desc="Generating embeddings", unit="batch"
+        ) as pbar:
+
+            async def wrapped_task(batch):
+                result = await self.embedding_func(batch)
+                pbar.update(1)
+                return result
+
+            embeddings_list = await asyncio.gather(
+                *(wrapped_task(batch) for batch in batches)
+            )
+
+        embeddings = np.concatenate(embeddings_list)
+
+        list_points = [
+            models.PointStruct(
+                id=compute_mdhash_id_for_qdrant(d["id"]),
+                vector=embeddings[i],
+                payload=d,
+            )
+            for i, d in enumerate(list_data)
+        ]
+
+        results = self._client.upsert(
+            collection_name=self.namespace, points=list_points, wait=True
+        )
+        return results
+
+    async def query(self, query, top_k=5):
+        """
+        Return the ``top_k`` stored points most similar to ``query``.
+
+        :param query: Text to embed and search with.
+        :param top_k: Maximum number of results.
+        :return: One dict per hit: the stored payload plus "distance" (the
+            Qdrant score). "id" is the original record key kept in the
+            payload by ``upsert``; the Qdrant point ID is only used when the
+            payload has no "id".
+        """
+        embedding = await self.embedding_func([query])
+        results = self._client.search(
+            collection_name=self.namespace,
+            query_vector=embedding[0],
+            limit=top_k,
+            with_payload=True,
+        )
+        logger.debug(f"query result: {results}")
+        # The point ID is a hash of the record key, so it must not overwrite
+        # the payload's "id", which is the key callers know the record by.
+        return [
+            {"id": dp.id, **(dp.payload or {}), "distance": dp.score}
+            for dp in results
+        ]
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index d92b5ea4..347f0f4c 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -59,6 +59,7 @@ STORAGES = {
     "GremlinStorage": ".kg.gremlin_impl",
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
+    "QdrantVectorDBStorage": ".kg.qdrant_impl",
 }
 
 