updated clean of what implemented on BaseKVStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:31:12 +01:00
parent 3eba41aab6
commit 71a18d1de9
7 changed files with 33 additions and 21 deletions

View File

@@ -121,11 +121,11 @@ class BaseKVStorage(StorageNameSpace):
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError raise NotImplementedError
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return un-exist keys""" """Return un-exist keys"""
raise NotImplementedError raise NotImplementedError
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
raise NotImplementedError raise NotImplementedError
async def drop(self) -> None: async def drop(self) -> None:

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any
from lightrag.base import ( from lightrag.base import (
BaseKVStorage, BaseKVStorage,
@@ -25,7 +25,7 @@ class JsonKVStorage(BaseKVStorage):
async def index_done_callback(self): async def index_done_callback(self):
write_json(self._data, self._file_name) write_json(self._data, self._file_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return self._data.get(id) return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage):
for id in ids for id in ids
] ]
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
return set(data) - set(self._data.keys()) return set(data) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

View File

@@ -60,14 +60,14 @@ class MongoKVStorage(BaseKVStorage):
# Ensure collection exists # Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name) create_collection_if_not_exists(uri, database.name, self._collection_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}}) cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list() return await cursor.to_list()
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor} existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids return data - existing_ids
@@ -107,6 +107,9 @@ class MongoKVStorage(BaseKVStorage):
else: else:
return None return None
async def index_done_callback(self) -> None:
pass
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the collection""" """Drop the collection"""
await self._data.drop() await self._data.drop()

View File

@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data based on id.""" """Get doc_full data based on id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
res = [{k: v} for k, v in dict_res.items()] res = [{k: v} for k, v in dict_res.items()]
return res return res
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +248,7 @@ class OracleKVStorage(BaseKVStorage):
return set(keys) return set(keys)
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [ list_data = [
{ {
@@ -314,6 +314,8 @@ class OracleKVStorage(BaseKVStorage):
): ):
logger.info("full doc and chunk data had been saved into oracle db!") logger.info("full doc and chunk data had been saved into oracle db!")
async def drop(self) -> None:
raise NotImplementedError
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):

View File

@@ -4,7 +4,7 @@ import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Dict, List, Tuple, Union
import numpy as np import numpy as np
import pipmaster as pm import pipmaster as pm
@@ -185,7 +185,7 @@ class PGKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id.""" """Get doc_full data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id} params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def filter_keys(self, keys: List[str]) -> Set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format( sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +261,7 @@ class PGKVStorage(BaseKVStorage):
print(params) print(params)
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
pass pass
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -294,6 +294,8 @@ class PGKVStorage(BaseKVStorage):
): ):
logger.info("full doc and chunk data had been saved into postgresql db!") logger.info("full doc and chunk data had been saved into postgresql db!")
async def drop(self) -> None:
raise NotImplementedError
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Union from typing import Any
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import pipmaster as pm import pipmaster as pm
@@ -28,7 +28,7 @@ class RedisKVStorage(BaseKVStorage):
self._redis = Redis.from_url(redis_url, decode_responses=True) self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}") logger.info(f"Use Redis as KV {self.namespace}")
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
data = await self._redis.get(f"{self.namespace}:{id}") data = await self._redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None return json.loads(data) if data else None
@@ -39,7 +39,7 @@ class RedisKVStorage(BaseKVStorage):
results = await pipe.execute() results = await pipe.execute()
return [json.loads(result) if result else None for result in results] return [json.loads(result) if result else None for result in results]
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
for key in data: for key in data:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.namespace}:{key}")
@@ -48,7 +48,7 @@ class RedisKVStorage(BaseKVStorage):
existing_ids = {data[i] for i, exists in enumerate(results) if exists} existing_ids = {data[i] for i, exists in enumerate(results) if exists}
return set(data) - existing_ids return set(data) - existing_ids
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
for k, v in tqdm_async(data.items(), desc="Upserting"): for k, v in tqdm_async(data.items(), desc="Upserting"):
pipe.set(f"{self.namespace}:{k}", json.dumps(v)) pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +61,6 @@ class RedisKVStorage(BaseKVStorage):
keys = await self._redis.keys(f"{self.namespace}:*") keys = await self._redis.keys(f"{self.namespace}:*")
if keys: if keys:
await self._redis.delete(*keys) await self._redis.delete(*keys)
async def index_done_callback(self) -> None:
pass

View File

@@ -110,7 +110,7 @@ class TiDBKVStorage(BaseKVStorage):
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Fetch doc_full data by id.""" """Fetch doc_full data by id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id} params = {"id": id}
@@ -125,7 +125,7 @@ class TiDBKVStorage(BaseKVStorage):
) )
return await self.db.query(SQL, multirows=True) return await self.db.query(SQL, multirows=True)
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""过滤掉重复内容""" """过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
@@ -147,7 +147,7 @@ class TiDBKVStorage(BaseKVStorage):
return data return data
################ INSERT full_doc AND chunks ################ ################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data) self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -207,6 +207,8 @@ class TiDBKVStorage(BaseKVStorage):
): ):
logger.info("full doc and chunk data had been saved into TiDB db!") logger.info("full doc and chunk data had been saved into TiDB db!")
async def drop(self) -> None:
raise NotImplementedError
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):