Merge pull request #768 from ArnoChenFx/fix-backends

Fix various runtime errors
This commit is contained in:
zrguo
2025-02-14 09:39:27 +08:00
committed by GitHub
5 changed files with 16 additions and 16 deletions

View File

@@ -6,7 +6,6 @@ password = your-password
[mongodb] [mongodb]
uri = mongodb+srv://name:password@your-cluster-address uri = mongodb+srv://name:password@your-cluster-address
database = lightrag database = lightrag
graph = false
[redis] [redis]
uri=redis://localhost:6379/1 uri=redis://localhost:6379/1

View File

@@ -32,8 +32,8 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError( raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"

View File

@@ -44,7 +44,7 @@ class MongoKVStorage(BaseKVStorage):
database = client.get_database( database = client.get_database(
os.environ.get( os.environ.get(
"MONGO_DATABASE", "MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"), config.get("mongodb", "database", fallback="LightRAG"),
) )
) )
self._data = database.get_collection(self.namespace) self._data = database.get_collection(self.namespace)

View File

@@ -15,7 +15,6 @@ if not pm.is_installed("qdrant_client"):
from qdrant_client import QdrantClient, models from qdrant_client import QdrantClient, models
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -61,8 +60,8 @@ class QdrantVectorDBStorage(BaseVectorStorage):
client.create_collection(collection_name, **kwargs) client.create_collection(collection_name, **kwargs)
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
if cosine_threshold is None: if cosine_threshold is None:
raise ValueError( raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
@@ -138,12 +137,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
query_vector=embedding[0], query_vector=embedding[0],
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
score_threshold=self.cosine_better_than_threshold,
) )
logger.debug(f"query result: {results}") logger.debug(f"query result: {results}")
# 添加余弦相似度过滤
filtered_results = [ return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
dp for dp in results if dp.score >= self.cosine_better_than_threshold
]
return [
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
]

View File

@@ -80,7 +80,12 @@ STORAGE_IMPLEMENTATIONS = {
"required_methods": ["query", "upsert"], "required_methods": ["query", "upsert"],
}, },
"DOC_STATUS_STORAGE": { "DOC_STATUS_STORAGE": {
"implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"], "implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_pending_docs"], "required_methods": ["get_pending_docs"],
}, },
} }
@@ -421,7 +426,7 @@ class LightRAG:
# Verify storage implementation compatibility # Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name) self.verify_storage_implementation(storage_type, storage_name)
# Check environment variables # Check environment variables
self.check_storage_env_vars(storage_name) # self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields # Ensure vector_db_storage_cls_kwargs has required fields
default_vector_db_kwargs = { default_vector_db_kwargs = {