From 4cce14e65ea124a9a087ad1495dd38b0cd3a03c1 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sun, 9 Feb 2025 11:24:08 +0100
Subject: [PATCH] cleaned import

---
 lightrag/base.py             | 20 ++++----
 lightrag/kg/json_kv_impl.py  |  4 +-
 lightrag/kg/mongo_impl.py    |  4 +-
 lightrag/kg/oracle_impl.py   |  4 +-
 lightrag/kg/postgres_impl.py |  4 +-
 lightrag/kg/redis_impl.py    |  4 +-
 lightrag/kg/tidb_impl.py     |  4 +-
 lightrag/lightrag.py         | 88 ++++++++++++++++++------------------
 8 files changed, 62 insertions(+), 70 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index a91595b2..4b963b43 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,6 +1,8 @@
+from enum import Enum
 import os
 from dataclasses import dataclass, field
 from typing import (
+    Optional,
     TypedDict,
     Union,
     Literal,
@@ -8,6 +10,8 @@ from typing import (
     Any,
 )
 
+import numpy as np
+
 from .utils import EmbeddingFunc
 
 
@@ -99,9 +103,7 @@ class BaseKVStorage(StorageNameSpace):
     async def drop(self) -> None:
         raise NotImplementedError
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         raise NotImplementedError
 
 
@@ -148,12 +150,12 @@ class BaseGraphStorage(StorageNameSpace):
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
 
-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         raise NotImplementedError
 
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
-    ) -> Dict[str, List[Dict]]:
+    ) -> dict[str, list[dict]]:
         raise NotImplementedError
 
 
@@ -177,20 +179,20 @@ class DocProcessingStatus:
     updated_at: str  # ISO format timestamp
     chunks_count: Optional[int] = None  # Number of chunks after splitting
     error: Optional[str] = None  # Error message if failed
-    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+    metadata: dict[str, Any] = field(default_factory=dict)  # Additional metadata
 
 
 class DocStatusStorage(BaseKVStorage):
     """Base class for document status storage"""
 
-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         raise NotImplementedError
 
-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
         raise NotImplementedError
 
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 59da1b54..e9225375 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -51,8 +51,6 @@ class JsonKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         self._data = {}
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         result = [v for _, v in self._data.items() if v["status"] == status]
         return result if result else None
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index eb896b63..b7b438bd 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -77,9 +77,7 @@ class MongoKVStorage(BaseKVStorage):
         """Drop the collection"""
         await self._data.drop()
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Get documents by status and ids"""
         return self._data.find({"status": status})
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 0e55194d..c82db9a6 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -229,9 +229,7 @@ class OracleKVStorage(BaseKVStorage):
         res = [{k: v} for k, v in dict_res.items()]
         return res
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index d966fd85..01e3688a 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -231,9 +231,7 @@ class PGKVStorage(BaseKVStorage):
         else:
             return await self.db.query(sql, params, multirows=True)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 7c5c7030..f9283dda 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -59,9 +59,7 @@ class RedisKVStorage(BaseKVStorage):
         if keys:
             await self._redis.delete(*keys)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         pipe = self._redis.pipeline()
         for key in await self._redis.keys(f"{self.namespace}:*"):
             pipe.hgetall(key)
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index 55dbe303..1f454639 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -322,9 +322,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         merge_sql = SQL_TEMPLATES["insert_relationship"]
         await self.db.execute(merge_sql, data)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 69dd85e9..87018b53 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -4,11 +4,16 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Type, Union
+from typing import Any, Type, Union, cast
 import traceback
 from .operate import (
     chunking_by_token_size,
-    extract_entities
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
     # local_query,global_query,hybrid_query,,
 )
 
@@ -19,18 +24,21 @@
 from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
     DocStatus,
+    QueryParam,
+    StorageNameSpace,
 )
 from .namespace import NameSpace, make_namespace
 from .prompt import GRAPH_FIELD_SEP
 
+
 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
     "JsonKVStorage": ".kg.json_kv_impl",
@@ -351,9 +359,10 @@ class LightRAG:
     )
 
     async def ainsert(
-        self, string_or_strings: Union[str, list[str]],
-        split_by_character: str | None = None,
-        split_by_character_only: bool = False
+        self,
+        string_or_strings: Union[str, list[str]],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
     ):
         """Insert documents with checkpoint support
@@ -368,7 +377,6 @@
         await self.apipeline_process_chunks(split_by_character, split_by_character_only)
         await self.apipeline_process_extract_graph()
 
-
     def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
@@ -482,31 +490,27 @@ class LightRAG:
         logger.info(f"Stored {len(new_docs)} new unique documents")
 
     async def apipeline_process_chunks(
-            self,
-            split_by_character: str | None = None,
-            split_by_character_only: bool = False
-        ) -> None:
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
         """Get pendding documents, split into chunks,insert chunks"""
         # 1. get all pending and failed documents
         to_process_doc_keys: list[str] = []
 
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
+
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
 
         if not to_process_doc_keys:
             logger.info("All documents have been processed or are duplicates")
-            return
+            return
 
         full_docs_ids = await self.full_docs.get_by_ids(to_process_doc_keys)
         new_docs = {}
@@ -515,8 +519,8 @@
 
         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
-            return
-
+            return
+
         # 2. split docs into chunks, insert chunks, update doc status
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
@@ -526,11 +530,11 @@
                 batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 doc_status: dict[str, Any] = {
-                        "content_summary": doc["content_summary"],
-                        "content_length": doc["content_length"],
-                        "status": DocStatus.PROCESSING,
-                        "created_at": doc["created_at"],
-                        "updated_at": datetime.now().isoformat(),
+                    "content_summary": doc["content_summary"],
+                    "content_length": doc["content_length"],
+                    "status": DocStatus.PROCESSING,
+                    "created_at": doc["created_at"],
+                    "updated_at": datetime.now().isoformat(),
                 }
                 try:
                     await self.doc_status.upsert({doc_id: doc_status})
@@ -564,14 +568,16 @@
                 except Exception as e:
                     doc_status.update(
-                            {
-                                "status": DocStatus.FAILED,
-                                "error": str(e),
-                                "updated_at": datetime.now().isoformat(),
-                            }
-                        )
+                        {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
                     await self.doc_status.upsert({doc_id: doc_status})
-                    logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
+                    logger.error(
+                        f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    )
                     continue
 
     async def apipeline_process_extract_graph(self):
@@ -580,22 +586,18 @@
         """Get pendding documents, split into chunks,insert chunks"""
         # 1. get all pending and failed documents
         to_process_doc_keys: list[str] = []
 
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
+
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
 
         if not to_process_doc_keys:
             logger.info("All documents have been processed or are duplicates")
-            return
+            return
 
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
@@ -606,7 +608,7 @@
         async def process_chunk(chunk_id: str):
             async with semaphore:
-                chunks:dict[str, Any] = {
+                chunks: dict[str, Any] = {
                     i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
                 }
                 # Extract and store entities and relationships
@@ -1051,7 +1053,7 @@
             return content
         return content[:max_length] + "..."
 
-    async def get_processing_status(self) -> Dict[str, int]:
+    async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts
 
         Returns:
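
The patch excerpt ends mid-hunk above. Not part of the commit: a minimal, hypothetical usage sketch of the staged insert pipeline the lightrag.py hunks touch, assuming the names shown in the diff (LightRAG, ainsert, get_processing_status, DocStatus) keep these signatures; the constructor arguments are illustrative only.

# Hypothetical usage sketch (not part of the patch).
import asyncio

from lightrag import LightRAG


async def main() -> None:
    # Constructor arguments are illustrative; real setups also configure the
    # LLM and embedding functions.
    rag = LightRAG(working_dir="./rag_storage")

    # ainsert() stores the documents, then runs apipeline_process_chunks()
    # followed by apipeline_process_extract_graph(), as shown in the diff.
    await rag.ainsert(
        ["first document text ...", "second document text ..."],
        split_by_character=None,
        split_by_character_only=False,
    )

    # get_processing_status() returns a count per DocStatus value.
    print(await rag.get_processing_status())


asyncio.run(main())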