diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 70a60aa2..59da1b54 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -28,10 +28,10 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.get(id, None) + async def get_by_id(self, id: str) -> dict[str, Any]: + return self._data.get(id, {}) - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: return [ ( {k: v for k, v in self._data[id].items()} @@ -51,7 +51,7 @@ class JsonKVStorage(BaseKVStorage): async def drop(self) -> None: self._data = {} - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: result = [v for _, v in self._data.items() if v["status"] == status] diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index 8bd972c6..603487bc 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -72,7 +72,7 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} + self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Loaded document status storage with {len(self._data)} records") async def filter_keys(self, data: list[str]) -> set[str]: @@ -112,10 +112,9 @@ class JsonDocStatusStorage(DocStatusStorage): """ self._data.update(data) await self.index_done_callback() - return data - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.get(id) + async def get_by_id(self, id: str) -> dict[str, Any]: + return self._data.get(id, {}) async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: """Get document status by ID""" diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index ce703dfb..eb896b63 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -32,10 +32,10 @@ class MongoKVStorage(BaseKVStorage): async def all_keys(self) -> list[str]: return [x["_id"] for x in self._data.find({}, {"_id": 1})] - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any]: return self._data.find_one({"_id": id}) - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: return list(self._data.find({"_id": {"$in": ids}})) async def filter_keys(self, data: list[str]) -> set[str]: @@ -77,7 +77,7 @@ class MongoKVStorage(BaseKVStorage): """Drop the collection""" await self._data.drop() - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: """Get documents by status and ids""" diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 3c064eba..0e55194d 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any]: """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -191,12 +191,9 @@ class OracleKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - else: - res = await self.db.query(SQL, params) - if res: return res else: - return None + return await self.db.query(SQL, params) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" @@ -211,7 +208,7 @@ class OracleKVStorage(BaseKVStorage): else: return None - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -230,14 +227,9 @@ class OracleKVStorage(BaseKVStorage): for row in res: dict_res[row["mode"]][row["id"]] = row res = [{k: v} for k, v in dict_res.items()] - if res: - data = res # [{"data":i} for i in res] - # print(data) - return data - else: - return None + return res - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ba11fea7..d966fd85 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -183,7 +183,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any]: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -192,12 +192,9 @@ class PGKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - else: - res = await self.db.query(sql, params) - if res: return res else: - return None + return await self.db.query(sql, params) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" @@ -213,7 +210,7 @@ class PGKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -230,15 +227,11 @@ class PGKVStorage(BaseKVStorage): dict_res[mode] = {} for row in array_res: dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] + return [{k: v} for k, v in dict_res.items()] else: - res = await self.db.query(sql, params, multirows=True) - if res: - return res - else: - return None + return await self.db.query(sql, params, multirows=True) - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" @@ -454,12 +447,12 @@ class PGDocStatusStorage(DocStatusStorage): existed = set([element["id"] for element in result]) return set(data) - existed - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" params = {"workspace": self.db.workspace, "id": id} result = await self.db.query(sql, params, True) if result is None or result == []: - return None + return {} else: return DocProcessingStatus( content_length=result[0]["content_length"], diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 095cc3b6..7c5c7030 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage): data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: pipe = self._redis.pipeline() for id in ids: pipe.get(f"{self.namespace}:{id}") @@ -59,7 +59,7 @@ class RedisKVStorage(BaseKVStorage): if keys: await self._redis.delete(*keys) - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: pipe = self._redis.pipeline() diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index b8e6e985..55dbe303 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -108,31 +108,20 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - """根据 id 获取 doc_full 数据.""" + async def get_by_id(self, id: str) -> dict[str, Any]: + """Fetch doc_full data by id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL, params) - if res: - data = res # {"data":res} - # print (data) - return data - else: - return None + return await self.db.query(SQL, params) # Query by id - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: - """根据 id 获取 doc_chunks 数据""" + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Fetch doc_chunks data by id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) - res = await self.db.query(SQL, multirows=True) - if res: - data = res # [{"data":i} for i in res] - return data - else: - return None + return await self.db.query(SQL, multirows=True) async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" @@ -333,7 +322,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): merge_sql = SQL_TEMPLATES["insert_relationship"] await self.db.execute(merge_sql, data) - async def get_by_status_and_ids( + async def get_by_status( self, status: str ) -> Union[list[dict[str, Any]], None]: SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]