Clean up the BaseKVStorage implementations

Yannick Stephan
2025-02-16 13:31:12 +01:00
parent 3eba41aab6
commit 71a18d1de9
7 changed files with 33 additions and 21 deletions
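
Taken together, the per-file diffs below align every KV backend on one set of signatures. A minimal sketch of the BaseKVStorage contract implied by the changes (reconstructed from the new signatures only; the actual abstract class in lightrag.base may declare more than this):

from typing import Any


class BaseKVStorage:
    """Signatures shared by all KV backends after this commit."""

    async def get_by_id(self, id: str) -> dict[str, Any] | None: ...

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: ...

    # Returns the subset of keys that have no stored document yet.
    async def filter_keys(self, keys: set[str]) -> set[str]: ...

    # Maps each key to one JSON-serializable document.
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None: ...

    # Flush hook called once indexing finishes; write-through backends no-op.
    async def index_done_callback(self) -> None: ...

    # Remove all stored data (raises NotImplementedError where unsupported).
    async def drop(self) -> None: ...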

View File

@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any
 from lightrag.base import (
     BaseKVStorage,
@@ -25,7 +25,7 @@ class JsonKVStorage(BaseKVStorage):
     async def index_done_callback(self):
         write_json(self._data, self._file_name)
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return self._data.get(id)
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]
-    async def filter_keys(self, data: set[str]) -> set[str]:
-        return set(data) - set(self._data.keys())
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        return set(keys) - set(self._data.keys())
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
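
The recurring return-type change, Union[dict[str, Any], None] → dict[str, Any] | None, is purely notational: it adopts the PEP 604 union syntax, which also lets the Union import be dropped. Note that the | form requires Python 3.10 or newer. For reference:

from typing import Any, Optional, Union

# All three annotations describe the same type; the PEP 604 form (last)
# needs no typing import but requires Python >= 3.10.
a: Union[dict[str, Any], None] = None
b: Optional[dict[str, Any]] = None
c: dict[str, Any] | None = None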

View File

@@ -60,14 +60,14 @@ class MongoKVStorage(BaseKVStorage):
         # Ensure collection exists
         create_collection_if_not_exists(uri, database.name, self._collection_name)
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return await self._data.find_one({"_id": id})
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
         return await cursor.to_list()
-    async def filter_keys(self, data: set[str]) -> set[str]:
-        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+    async def filter_keys(self, keys: set[str]) -> set[str]:
+        cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
         existing_ids = {str(x["_id"]) async for x in cursor}
-        return data - existing_ids
+        return keys - existing_ids
@@ -107,6 +107,9 @@ class MongoKVStorage(BaseKVStorage):
         else:
             return None
+    async def index_done_callback(self) -> None:
+        pass
     async def drop(self) -> None:
         """Drop the collection"""
         await self._data.drop()
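
Every backend now takes a set[str] in filter_keys and returns the subset of keys with no stored document, so callers can skip re-inserting existing content. A hypothetical usage sketch (storage stands in for any BaseKVStorage instance and docs for a key-to-document dict; neither name comes from the diff):

from typing import Any


async def upsert_new_only(storage, docs: dict[str, dict[str, Any]]) -> None:
    # Keep only the documents whose keys are not in storage yet.
    new_keys = await storage.filter_keys(set(docs))
    if new_keys:
        await storage.upsert({k: docs[k] for k in new_keys})
    # Give file-backed stores a chance to flush to disk.
    await storage.index_done_callback()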

View File

@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
     ################ QUERY METHODS ################
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
         res = [{k: v} for k, v in dict_res.items()]
         return res
-    async def filter_keys(self, keys: list[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +248,7 @@ class OracleKVStorage(BaseKVStorage):
             return set(keys)
     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
@@ -314,6 +314,8 @@ class OracleKVStorage(BaseKVStorage):
         ):
             logger.info("full doc and chunk data had been saved into oracle db!")
+    async def drop(self) -> None:
+        raise NotImplementedError
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):

View File

@@ -4,7 +4,7 @@ import json
 import os
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Set, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 import numpy as np
 import pipmaster as pm
@@ -185,7 +185,7 @@ class PGKVStorage(BaseKVStorage):
     ################ QUERY METHODS ################
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
-    async def filter_keys(self, keys: List[str]) -> Set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +261,7 @@ class PGKVStorage(BaseKVStorage):
             print(params)
     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -294,6 +294,8 @@ class PGKVStorage(BaseKVStorage):
         ):
             logger.info("full doc and chunk data had been saved into postgresql db!")
+    async def drop(self) -> None:
+        raise NotImplementedError
 @dataclass
 class PGVectorStorage(BaseVectorStorage):

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Union
+from typing import Any
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -28,7 +28,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
@@ -39,7 +39,7 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
-        for key in data:
+        for key in keys:
             pipe.exists(f"{self.namespace}:{key}")
@@ -48,7 +48,7 @@ class RedisKVStorage(BaseKVStorage):
-        existing_ids = {data[i] for i, exists in enumerate(results) if exists}
-        return set(data) - existing_ids
+        existing_ids = {key for key, exists in zip(keys, results) if exists}
+        return keys - existing_ids
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         pipe = self._redis.pipeline()
         for k, v in tqdm_async(data.items(), desc="Upserting"):
             pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +61,6 @@ class RedisKVStorage(BaseKVStorage):
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
+    async def index_done_callback(self) -> None:
+        pass
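
The two new no-op index_done_callback overrides (here and in MongoKVStorage above) make sense because Redis and MongoDB persist every upsert immediately, so there is nothing left to flush when indexing finishes; JsonKVStorage is the exception, using the callback to write its in-memory dict to disk via write_json. Note also that indexing the set-typed parameter (the old data[i]) could never have worked; pairing each key with its pipelined EXISTS result via zip, as above, relies on the results coming back in queue order.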

View File

@@ -110,7 +110,7 @@ class TiDBKVStorage(BaseKVStorage):
     ################ QUERY METHODS ################
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
@@ -125,7 +125,7 @@ class TiDBKVStorage(BaseKVStorage):
         )
         return await self.db.query(SQL, multirows=True)
-    async def filter_keys(self, keys: list[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -147,7 +147,7 @@ class TiDBKVStorage(BaseKVStorage):
             return data
     ################ INSERT full_doc AND chunks ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -207,6 +207,8 @@ class TiDBKVStorage(BaseKVStorage):
         ):
             logger.info("full doc and chunk data had been saved into TiDB db!")
+    async def drop(self) -> None:
+        raise NotImplementedError
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
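
All three SQL backends (Oracle, PostgreSQL, TiDB) gain a drop() that raises NotImplementedError, while Mongo drops its collection and Redis deletes its namespace keys. Callers that reset storage may therefore want to tolerate the unimplemented case; a hypothetical guard (reset_namespace is an illustrative name, not part of the codebase):

async def reset_namespace(storage) -> None:
    # drop() is a real operation on Mongo/Redis but a stub on the SQL backends.
    try:
        await storage.drop()
    except NotImplementedError:
        print(f"drop() not supported by {type(storage).__name__}")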