From 6480ddee5ddaae84f47328295e0d4051d810e6b0 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 9 Feb 2025 19:51:05 +0100 Subject: [PATCH] cleaned code --- lightrag/base.py | 28 +++++++++++++++------------- lightrag/kg/json_kv_impl.py | 18 +++++++++--------- lightrag/kg/jsondocstatus_impl.py | 2 +- lightrag/kg/mongo_impl.py | 19 ++++++++++--------- lightrag/kg/oracle_impl.py | 24 ++++++++++++++---------- lightrag/kg/postgres_impl.py | 30 ++++++++++++++++-------------- lightrag/kg/redis_impl.py | 4 ++-- lightrag/kg/tidb_impl.py | 21 ++++++++++----------- 8 files changed, 77 insertions(+), 69 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 9b3e5f00..1a7f9c2e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,24 +1,26 @@ -from enum import Enum import os from dataclasses import dataclass, field +from enum import Enum from typing import ( + Any, + Literal, Optional, TypedDict, - Union, - Literal, TypeVar, - Any, + Union, ) import numpy as np - from .utils import EmbeddingFunc -TextChunkSchema = TypedDict( - "TextChunkSchema", - {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}, -) + +class TextChunkSchema(TypedDict): + tokens: int + content: str + full_doc_id: str + chunk_order_index: int + T = TypeVar("T") @@ -57,11 +59,11 @@ class StorageNameSpace: global_config: dict[str, Any] async def index_done_callback(self): - """commit the storage operations after indexing""" + """Commit the storage operations after indexing""" pass async def query_done_callback(self): - """commit the storage operations after querying""" + """Commit the storage operations after querying""" pass @@ -84,14 +86,14 @@ class BaseVectorStorage(StorageNameSpace): class BaseKVStorage(StorageNameSpace): embedding_func: EmbeddingFunc - async def get_by_id(self, id: str) -> dict[str, Any]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: raise NotImplementedError async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: raise NotImplementedError async def filter_keys(self, data: set[str]) -> set[str]: - """return un-exist keys""" + """Return un-exist keys""" raise NotImplementedError async def upsert(self, data: dict[str, Any]) -> None: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index c61d088d..e545c650 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,16 +1,16 @@ import asyncio import os from dataclasses import dataclass -from typing import Any +from typing import Any, Union -from lightrag.utils import ( - logger, - load_json, - write_json, -) from lightrag.base import ( BaseKVStorage, ) +from lightrag.utils import ( + load_json, + logger, + write_json, +) @dataclass @@ -25,8 +25,8 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id: str) -> dict[str, Any]: - return self._data.get(id, {}) + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: return [ @@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage): ] async def filter_keys(self, data: set[str]) -> set[str]: - return set([s for s in data if s not in self._data]) + return data - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index 179b17a3..2ff06d3a 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -76,7 +76,7 @@ class JsonDocStatusStorage(DocStatusStorage): async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return {k for k, _ in self._data.items() if k in data} + return set(k for k in data if k not in self._data) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 1294a26a..4f919ecd 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,8 +1,9 @@ import os -from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass -import pipmaster as pm + import numpy as np +import pipmaster as pm +from tqdm.asyncio import tqdm as tqdm_async if not pm.is_installed("pymongo"): pm.install("pymongo") @@ -10,13 +11,14 @@ if not pm.is_installed("pymongo"): if not pm.is_installed("motor"): pm.install("motor") -from pymongo import MongoClient -from motor.motor_asyncio import AsyncIOMotorClient -from typing import Any, Union, List, Tuple +from typing import Any, List, Tuple, Union -from ..utils import logger -from ..base import BaseKVStorage, BaseGraphStorage +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo import MongoClient + +from ..base import BaseGraphStorage, BaseKVStorage from ..namespace import NameSpace, is_namespace +from ..utils import logger @dataclass @@ -29,7 +31,7 @@ class MongoKVStorage(BaseKVStorage): self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") - async def get_by_id(self, id: str) -> dict[str, Any]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -170,7 +172,6 @@ class MongoGraphStorage(BaseGraphStorage): But typically for a direct edge, we might just do a find_one. Below is a demonstration approach. """ - # We can do a single-hop graphLookup (maxDepth=0 or 1). # Then check if the target_node appears among the edges array. pipeline = [ diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index b648c9bc..ca6bcfb2 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -1,27 +1,28 @@ -import os +import array import asyncio +import os # import html # import os from dataclasses import dataclass from typing import Any, Union + import numpy as np -import array import pipmaster as pm if not pm.is_installed("oracledb"): pm.install("oracledb") -from ..utils import logger +import oracledb + from ..base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, ) from ..namespace import NameSpace, is_namespace - -import oracledb +from ..utils import logger class OracleDB: @@ -107,7 +108,7 @@ class OracleDB: "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only" ) else: - await self.query("SELECT 1 FROM {k}".format(k=k)) + await self.query(f"SELECT 1 FROM {k}") except Exception as e: logger.error(f"Failed to check table {k} in Oracle database") logger.error(f"Oracle database error: {e}") @@ -181,8 +182,8 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> dict[str, Any]: - """get doc_full data based on id.""" + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + """Get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) @@ -191,7 +192,10 @@ class OracleKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - return res + if res: + return res + else: + return None else: return await self.db.query(SQL, params) @@ -209,7 +213,7 @@ class OracleKVStorage(BaseKVStorage): return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """get doc_chunks data based on id""" + """Get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 63df869e..d319f6f9 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,34 +4,35 @@ import json import os import time from dataclasses import dataclass -from typing import Union, List, Dict, Set, Any, Tuple -import numpy as np +from typing import Any, Dict, List, Set, Tuple, Union +import numpy as np import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import asyncpg import sys -from tqdm.asyncio import tqdm as tqdm_async + +import asyncpg from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) +from tqdm.asyncio import tqdm as tqdm_async -from ..utils import logger from ..base import ( + BaseGraphStorage, BaseKVStorage, BaseVectorStorage, - DocStatusStorage, - DocStatus, DocProcessingStatus, - BaseGraphStorage, + DocStatus, + DocStatusStorage, ) from ..namespace import NameSpace, is_namespace +from ..utils import logger if sys.platform.startswith("win"): import asyncio.windows_events @@ -82,7 +83,7 @@ class PostgreSQLDB: async def check_tables(self): for k, v in TABLES.items(): try: - await self.query("SELECT 1 FROM {k} LIMIT 1".format(k=k)) + await self.query(f"SELECT 1 FROM {k} LIMIT 1") except Exception as e: logger.error(f"Failed to check table {k} in PostgreSQL database") logger.error(f"PostgreSQL database error: {e}") @@ -183,7 +184,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> dict[str, Any]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -192,9 +193,10 @@ class PGKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - return res + return res if res else None else: - return await self.db.query(sql, params) + response = await self.db.query(sql, params) + return response if response else None async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" @@ -435,12 +437,12 @@ class PGDocStatusStorage(DocStatusStorage): existed = set([element["id"] for element in result]) return set(data) - existed - async def get_by_id(self, id: str) -> dict[str, Any]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" params = {"workspace": self.db.workspace, "id": id} result = await self.db.query(sql, params, True) if result is None or result == []: - return {} + return None else: return DocProcessingStatus( content=result[0]["content"], diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index ef95d6db..e97a6afc 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, Union from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -21,7 +21,7 @@ class RedisKVStorage(BaseKVStorage): self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}") - async def get_by_id(self, id): + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 1f454639..d9eeb2dd 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -14,12 +14,12 @@ if not pm.is_installed("sqlalchemy"): from sqlalchemy import create_engine, text from tqdm import tqdm -from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage -from ..utils import logger +from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace +from ..utils import logger -class TiDB(object): +class TiDB: def __init__(self, config, **kwargs): self.host = config.get("host", None) self.port = config.get("port", None) @@ -108,12 +108,12 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> dict[str, Any]: + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """Fetch doc_full data by id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} - # print("get_by_id:"+SQL) - return await self.db.query(SQL, params) + response = await self.db.query(SQL, params) + return response if response else None # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -178,7 +178,7 @@ class TiDBKVStorage(BaseKVStorage): "tokens": item["tokens"], "chunk_order_index": item["chunk_order_index"], "full_doc_id": item["full_doc_id"], - "content_vector": f"{item['__vector__'].tolist()}", + "content_vector": f'{item["__vector__"].tolist()}', "workspace": self.db.workspace, } ) @@ -222,8 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): ) async def query(self, query: str, top_k: int) -> list[dict]: - """search from tidb vector""" - + """Search from tidb vector""" embeddings = await self.embedding_func([query]) embedding = embeddings[0] @@ -286,7 +285,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): "id": item["id"], "name": item["entity_name"], "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", + "content_vector": f'{item["content_vector"].tolist()}', "workspace": self.db.workspace, } # update entity_id if node inserted by graph_storage_instance before @@ -308,7 +307,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): "source_name": item["src_id"], "target_name": item["tgt_id"], "content": item["content"], - "content_vector": f"{item['content_vector'].tolist()}", + "content_vector": f'{item["content_vector"].tolist()}', "workspace": self.db.workspace, } # update relation_id if node inserted by graph_storage_instance before