From cff415d91f2d0f86c77087a7f21d2f03d0b7b760 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 8 Feb 2025 23:18:12 +0100 Subject: [PATCH] implemented method and cleaned the mess --- lightrag/kg/json_kv_impl.py | 63 +++---------------------------- lightrag/kg/jsondocstatus_impl.py | 6 +-- lightrag/kg/mongo_impl.py | 30 +++++++-------- lightrag/kg/oracle_impl.py | 31 ++++----------- lightrag/kg/postgres_impl.py | 17 ++++++--- lightrag/kg/redis_impl.py | 28 +++++++------- lightrag/kg/tidb_impl.py | 16 +++++--- 7 files changed, 66 insertions(+), 125 deletions(-) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index f19463ce..6ee49f7c 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,53 +1,3 @@ -""" -JsonDocStatus Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import asyncio import os from dataclasses import dataclass @@ -58,12 +8,10 @@ from lightrag.utils import ( load_json, write_json, ) - from lightrag.base import ( BaseKVStorage, ) - @dataclass class JsonKVStorage(BaseKVStorage): def __post_init__(self): @@ -79,13 +27,13 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id: str): + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.get(id, None) - async def get_by_ids(self, ids: list[str]): + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: return [ ( - {k: v for k, v in self._data[id].items() } + {k: v for k, v in self._data[id].items()} if self._data.get(id, None) else None ) @@ -95,12 +43,11 @@ class JsonKVStorage(BaseKVStorage): async def filter_keys(self, data: list[str]) -> set[str]: return set([s for s in data if s not in self._data]) - async def upsert(self, data: dict[str, dict[str, Any]]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) - return left_data - async def drop(self): + async def drop(self) -> None: self._data = {} async def get_by_status_and_ids( diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index 8f326170..8bd972c6 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -50,7 +50,7 @@ Usage: import os from dataclasses import dataclass -from typing import Union, Dict +from typing import Any, Union, Dict from lightrag.utils import ( logger, @@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage): """Save data to file after indexing""" write_json(self._data, self._file_name) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: """Update or insert document status Args: @@ -114,7 +114,7 @@ class JsonDocStatusStorage(DocStatusStorage): await self.index_done_callback() return data - async def get_by_id(self, id: str): + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.get(id) async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 7afc4240..d0598ca4 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -12,7 +12,7 @@ if not pm.is_installed("motor"): from pymongo import MongoClient from motor.motor_asyncio import AsyncIOMotorClient -from typing import Union, List, Tuple +from typing import Any, TypeVar, Union, List, Tuple from ..utils import logger from ..base import BaseKVStorage, BaseGraphStorage @@ -32,18 +32,11 @@ class MongoKVStorage(BaseKVStorage): async def all_keys(self) -> list[str]: return [x["_id"] for x in self._data.find({}, {"_id": 1})] - async def get_by_id(self, id): + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.find_one({"_id": id}) - async def get_by_ids(self, ids, fields=None): - if fields is None: - return list(self._data.find({"_id": {"$in": ids}})) - return list( - self._data.find( - {"_id": {"$in": ids}}, - {field: 1 for field in fields}, - ) - ) + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + return list(self._data.find({"_id": {"$in": ids}})) async def filter_keys(self, data: list[str]) -> set[str]: existing_ids = [ @@ -51,7 +44,7 @@ class MongoKVStorage(BaseKVStorage): ] return set([s for s in data if s not in existing_ids]) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): for mode, items in data.items(): for k, v in tqdm_async(items.items(), desc="Upserting"): @@ -66,7 +59,6 @@ class MongoKVStorage(BaseKVStorage): for k, v in tqdm_async(data.items(), desc="Upserting"): self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k - return data async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -81,9 +73,15 @@ class MongoKVStorage(BaseKVStorage): else: return None - async def drop(self): - """ """ - pass + async def drop(self) -> None: + """Drop the collection""" + await self._data.drop() + + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: + """Get documents by status and ids""" + return self._data.find({"status": status}) @dataclass diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index a1a05759..9438c323 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -4,7 +4,7 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Union +from typing import Any, TypeVar, Union import numpy as np import array import pipmaster as pm @@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -211,7 +211,7 @@ class OracleKVStorage(BaseKVStorage): else: return None - async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -238,15 +238,10 @@ class OracleKVStorage(BaseKVStorage): return None async def get_by_status_and_ids( - self, status: str, ids: list[str] - ) -> Union[list[dict], None]: + self, status: str + ) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" - if ids is not None: - SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - else: - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] params = {"workspace": self.db.workspace, "status": status} res = await self.db.query(SQL, params, multirows=True) if res: @@ -270,7 +265,7 @@ class OracleKVStorage(BaseKVStorage): return set(keys) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { @@ -328,14 +323,6 @@ class OracleKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) - return None - - async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format( - table_name=namespace_to_table_name(self.namespace) - ) - params = {"workspace": self.db.workspace, "id": id, "status": status} - await self.db.execute(SQL, params) async def index_done_callback(self): if is_namespace( @@ -343,8 +330,7 @@ class OracleKVStorage(BaseKVStorage): (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), ): logger.info("full doc and chunk data had been saved into oracle db!") - - + @dataclass class OracleVectorDBStorage(BaseVectorStorage): # should pass db object to self.db @@ -745,7 +731,6 @@ SQL_TEMPLATES = { "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 8884d92e..dccb2d7f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -30,7 +30,6 @@ from ..base import ( DocStatus, DocProcessingStatus, BaseGraphStorage, - T, ) from ..namespace import NameSpace, is_namespace @@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -214,7 +213,7 @@ class PGKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: """Get doc_chunks data by id""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -238,6 +237,14 @@ class PGKVStorage(BaseKVStorage): return res else: return None + + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + return await self.db.query(SQL, params, multirows=True) async def all_keys(self) -> list[dict]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -270,7 +277,7 @@ class PGKVStorage(BaseKVStorage): print(params) ################ INSERT METHODS ################ - async def upsert(self, data: Dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): pass elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): @@ -447,7 +454,7 @@ class PGDocStatusStorage(DocStatusStorage): existed = set([element["id"] for element in result]) return set(data) - existed - async def get_by_id(self, id: str) -> Union[T, None]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" params = {"workspace": self.db.workspace, "id": id} result = await self.db.query(sql, params, True) diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 147ea5f3..15faa843 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,4 +1,5 @@ import os +from typing import Any, TypeVar, Union from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -28,21 +29,11 @@ class RedisKVStorage(BaseKVStorage): data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None - async def get_by_ids(self, ids, fields=None): + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: pipe = self._redis.pipeline() for id in ids: pipe.get(f"{self.namespace}:{id}") results = await pipe.execute() - - if fields: - # Filter fields if specified - return [ - {field: value.get(field) for field in fields if field in value} - if (value := json.loads(result)) - else None - for result in results - ] - return [json.loads(result) if result else None for result in results] async def filter_keys(self, data: list[str]) -> set[str]: @@ -54,7 +45,7 @@ class RedisKVStorage(BaseKVStorage): existing_ids = {data[i] for i, exists in enumerate(results) if exists} return set(data) - existing_ids - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: pipe = self._redis.pipeline() for k, v in tqdm_async(data.items(), desc="Upserting"): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) @@ -62,9 +53,18 @@ class RedisKVStorage(BaseKVStorage): for k in data: data[k]["_id"] = k - return data - async def drop(self): + async def drop(self) -> None: keys = await self._redis.keys(f"{self.namespace}:*") if keys: await self._redis.delete(*keys) + + async def get_by_status_and_ids( + self, status: str, + ) -> Union[list[dict[str, Any]], None]: + pipe = self._redis.pipeline() + for key in await self._redis.keys(f"{self.namespace}:*"): + pipe.hgetall(key) + results = await pipe.execute() + return [data for data in results if data.get("status") == status] or None + diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index cb819d47..97d5794f 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Union +from typing import Any, TypeVar, Union import numpy as np import pipmaster as pm @@ -108,7 +108,7 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """根据 id 获取 doc_full 数据.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} @@ -122,16 +122,14 @@ class TiDBKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: """根据 id 获取 doc_chunks 数据""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) - # print("get_by_ids:"+SQL) res = await self.db.query(SQL, multirows=True) if res: data = res # [{"data":i} for i in res] - # print(data) return data else: return None @@ -158,7 +156,7 @@ class TiDBKVStorage(BaseKVStorage): return data ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): @@ -335,6 +333,12 @@ class TiDBVectorDBStorage(BaseVectorStorage): merge_sql = SQL_TEMPLATES["insert_relationship"] await self.db.execute(merge_sql, data) + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict], None]: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + return await self.db.query(SQL, params, multirows=True) @dataclass class TiDBGraphStorage(BaseGraphStorage):