Merge pull request #738 from ArnoChenFx/qdrant-backend
add qdrant backend, enable MongoGraphStorage based on config
This commit is contained in:
15
config.ini.example
Normal file
15
config.ini.example
Normal file
@@ -0,0 +1,15 @@
|
||||
[neo4j]
|
||||
uri = neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||
username = neo4j
|
||||
password = your-password
|
||||
|
||||
[mongodb]
|
||||
uri = mongodb+srv://name:password@your-cluster-address
|
||||
database = lightrag
|
||||
graph = false
|
||||
|
||||
[redis]
|
||||
uri=redis://localhost:6379/1
|
||||
|
||||
[qdrant]
|
||||
uri = http://localhost:16333
|
@@ -113,14 +113,26 @@ if milvus_uri:
|
||||
os.environ["MILVUS_DB_NAME"] = milvus_db_name
|
||||
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
|
||||
|
||||
# Qdrant config
|
||||
qdrant_uri = config.get("qdrant", "uri", fallback=None)
|
||||
qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
|
||||
if qdrant_uri:
|
||||
os.environ["QDRANT_URL"] = qdrant_uri
|
||||
if qdrant_api_key:
|
||||
os.environ["QDRANT_API_KEY"] = qdrant_api_key
|
||||
rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
|
||||
|
||||
# MongoDB config
|
||||
mongo_uri = config.get("mongodb", "uri", fallback=None)
|
||||
mongo_database = config.get("mongodb", "LightRAG", fallback=None)
|
||||
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
|
||||
mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
|
||||
if mongo_uri:
|
||||
os.environ["MONGO_URI"] = mongo_uri
|
||||
os.environ["MONGO_DATABASE"] = mongo_database
|
||||
rag_storage_config.KV_STORAGE = "MongoKVStorage"
|
||||
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
|
||||
if mongo_graph:
|
||||
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
|
127
lightrag/kg/qdrant_impl.py
Normal file
127
lightrag/kg/qdrant_impl.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import asyncio
|
||||
import os
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("qdrant_client"):
|
||||
pm.install("qdrant_client")
|
||||
|
||||
from qdrant_client import QdrantClient, models
|
||||
|
||||
|
||||
def compute_mdhash_id_for_qdrant(
|
||||
content: str, prefix: str = "", style: str = "simple"
|
||||
) -> str:
|
||||
"""
|
||||
Generate a UUID based on the content and support multiple formats.
|
||||
|
||||
:param content: The content used to generate the UUID.
|
||||
:param style: The format of the UUID, optional values are "simple", "hyphenated", "urn".
|
||||
:return: A UUID that meets the requirements of Qdrant.
|
||||
"""
|
||||
if not content:
|
||||
raise ValueError("Content must not be empty.")
|
||||
|
||||
# Use the hash value of the content to create a UUID.
|
||||
hashed_content = hashlib.sha256((prefix + content).encode("utf-8")).digest()
|
||||
generated_uuid = uuid.UUID(bytes=hashed_content[:16], version=4)
|
||||
|
||||
# Return the UUID according to the specified format.
|
||||
if style == "simple":
|
||||
return generated_uuid.hex
|
||||
elif style == "hyphenated":
|
||||
return str(generated_uuid)
|
||||
elif style == "urn":
|
||||
return f"urn:uuid:{generated_uuid}"
|
||||
else:
|
||||
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: QdrantClient, collection_name: str, **kwargs
|
||||
):
|
||||
if client.collection_exists(collection_name):
|
||||
return
|
||||
client.create_collection(collection_name, **kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
self._client = QdrantClient(
|
||||
url=os.environ.get("QDRANT_URL"),
|
||||
api_key=os.environ.get("QDRANT_API_KEY", None),
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
QdrantVectorDBStorage.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
vectors_config=models.VectorParams(
|
||||
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
|
||||
),
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
list_data = [
|
||||
{
|
||||
"id": k,
|
||||
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(
|
||||
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
||||
)
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
|
||||
list_points = []
|
||||
for i, d in enumerate(list_data):
|
||||
list_points.append(
|
||||
models.PointStruct(
|
||||
id=compute_mdhash_id_for_qdrant(d["id"]),
|
||||
vector=embeddings[i],
|
||||
payload=d,
|
||||
)
|
||||
)
|
||||
|
||||
results = self._client.upsert(
|
||||
collection_name=self.namespace, points=list_points, wait=True
|
||||
)
|
||||
return results
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
query_vector=embedding[0],
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
)
|
||||
logger.debug(f"query result: {results}")
|
||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
|
@@ -59,6 +59,7 @@ STORAGES = {
|
||||
"GremlinStorage": ".kg.gremlin_impl",
|
||||
"PGDocStatusStorage": ".kg.postgres_impl",
|
||||
"FaissVectorDBStorage": ".kg.faiss_impl",
|
||||
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user