Fix linting
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
@@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
user_collection_settings = config.get("collection_settings", {})
|
||||
|
@@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
# Where to save index file if you want persistent storage
|
||||
|
@@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client = MilvusClient(
|
||||
@@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
data=embedding,
|
||||
limit=top_k,
|
||||
output_fields=list(self.meta_fields),
|
||||
search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
|
||||
search_params={
|
||||
"metric_type": "COSINE",
|
||||
"params": {"radius": self.cosine_better_than_threshold},
|
||||
},
|
||||
)
|
||||
print(results)
|
||||
return [
|
||||
|
@@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client_file_name = os.path.join(
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import array
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
# import html
|
||||
# import os
|
||||
@@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
|
@@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
def _upsert_chunks(self, item: dict):
|
||||
@@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
keys = ",".join([f"'{_id}'" for _id in data])
|
||||
sql = (
|
||||
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||
)
|
||||
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||
result = await self.db.query(sql, multirows=True)
|
||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||
if result is None:
|
||||
|
@@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client = QdrantClient(
|
||||
@@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
logger.debug(f"query result: {results}")
|
||||
# 添加余弦相似度过滤
|
||||
filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold]
|
||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results]
|
||||
filtered_results = [
|
||||
dp for dp in results if dp.score >= self.cosine_better_than_threshold
|
||||
]
|
||||
return [
|
||||
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
|
||||
]
|
||||
|
@@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
|
Reference in New Issue
Block a user