From 71a18d1de97f7ab1d62174ff2493590e39f8d74b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:31:12 +0100 Subject: [PATCH] updated clean of what implemented on BaseKVStorage --- lightrag/base.py | 4 ++-- lightrag/kg/json_kv_impl.py | 6 +++--- lightrag/kg/mongo_impl.py | 7 +++++-- lightrag/kg/oracle_impl.py | 8 +++++--- lightrag/kg/postgres_impl.py | 10 ++++++---- lightrag/kg/redis_impl.py | 11 +++++++---- lightrag/kg/tidb_impl.py | 8 +++++--- 7 files changed, 33 insertions(+), 21 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3d4fc022..8efbe8a2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -121,11 +121,11 @@ class BaseKVStorage(StorageNameSpace): async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: raise NotImplementedError - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return un-exist keys""" raise NotImplementedError - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: raise NotImplementedError async def drop(self) -> None: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 3ab5b966..5683801f 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any from lightrag.base import ( BaseKVStorage, @@ -25,7 +25,7 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage): for id in ids ] - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: return set(data) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 39bb9f18..44820ecf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -60,14 +60,14 @@ class MongoKVStorage(BaseKVStorage): # Ensure collection exists create_collection_if_not_exists(uri, database.name, self._collection_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: cursor = self._data.find({"_id": {"$in": ids}}) return await cursor.to_list() - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) existing_ids = {str(x["_id"]) async for x in cursor} return data - existing_ids @@ -107,6 +107,9 @@ class MongoKVStorage(BaseKVStorage): else: return None + async def index_done_callback(self) -> None: + pass + async def drop(self) -> None: """Drop the collection""" await self._data.drop() diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 197d101e..95d888b3 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -232,7 +232,7 @@ class OracleKVStorage(BaseKVStorage): res = [{k: v} for k, v in dict_res.items()] return res - async def filter_keys(self, keys: list[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -248,7 +248,7 @@ class OracleKVStorage(BaseKVStorage): return set(keys) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { @@ -314,6 +314,8 @@ class OracleKVStorage(BaseKVStorage): ): logger.info("full doc and chunk data had been saved into oracle db!") + async def drop(self) -> None: + raise NotImplementedError @dataclass class OracleVectorDBStorage(BaseVectorStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 5dbc6a8e..98f9c495 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,7 +4,7 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import numpy as np import pipmaster as pm @@ -185,7 +185,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -240,7 +240,7 @@ class PGKVStorage(BaseKVStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) - async def filter_keys(self, keys: List[str]) -> Set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" sql = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -261,7 +261,7 @@ class PGKVStorage(BaseKVStorage): print(params) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): pass elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): @@ -294,6 +294,8 @@ class PGKVStorage(BaseKVStorage): ): logger.info("full doc and chunk data had been saved into postgresql db!") + async def drop(self) -> None: + raise NotImplementedError @dataclass class PGVectorStorage(BaseVectorStorage): diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index ed8f46f9..f735c72a 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,5 +1,5 @@ import os -from typing import Any, Union +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -28,7 +28,7 @@ class RedisKVStorage(BaseKVStorage): self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}") - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None @@ -39,7 +39,7 @@ class RedisKVStorage(BaseKVStorage): results = await pipe.execute() return [json.loads(result) if result else None for result in results] - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: pipe = self._redis.pipeline() for key in data: pipe.exists(f"{self.namespace}:{key}") @@ -48,7 +48,7 @@ class RedisKVStorage(BaseKVStorage): existing_ids = {data[i] for i, exists in enumerate(results) if exists} return set(data) - existing_ids - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pipe = self._redis.pipeline() for k, v in tqdm_async(data.items(), desc="Upserting"): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) @@ -61,3 +61,6 @@ class RedisKVStorage(BaseKVStorage): keys = await self._redis.keys(f"{self.namespace}:*") if keys: await self._redis.delete(*keys) + + async def index_done_callback(self) -> None: + pass \ No newline at end of file diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index a5a5c80d..6f388e7f 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -110,7 +110,7 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Fetch doc_full data by id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} @@ -125,7 +125,7 @@ class TiDBKVStorage(BaseKVStorage): ) return await self.db.query(SQL, multirows=True) - async def filter_keys(self, keys: list[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -147,7 +147,7 @@ class TiDBKVStorage(BaseKVStorage): return data ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): @@ -207,6 +207,8 @@ class TiDBKVStorage(BaseKVStorage): ): logger.info("full doc and chunk data had been saved into TiDB db!") + async def drop(self) -> None: + raise NotImplementedError @dataclass class TiDBVectorDBStorage(BaseVectorStorage):