From 805da7b95b75d0b955913633e97f23ab1c64b0e8 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:01:21 +0100 Subject: [PATCH] cleaned code --- lightrag/base.py | 8 +++- lightrag/kg/faiss_impl.py | 2 +- lightrag/kg/nano_vector_db_impl.py | 2 +- lightrag/lightrag.py | 68 ++++++------------------------ 4 files changed, 21 insertions(+), 59 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 2b3e5cad..29335494 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -106,10 +106,16 @@ class BaseVectorStorage(StorageNameSpace): """ raise NotImplementedError + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError + @dataclass class BaseKVStorage(StorageNameSpace): - embedding_func: EmbeddingFunc + embedding_func: EmbeddingFunc | None = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: raise NotImplementedError diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0dca9e4c..9a5f7e4e 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: """ Delete relations for a given entity by scanning metadata. """ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 2db8f72a..5d786646 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: try: relations = [ dp diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 08855d60..ce86e938 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1095,7 +1095,7 @@ class LightRAG: async def aquery_with_separate_keyword_extraction( self, query: str, prompt: str, param: QueryParam = QueryParam() - ): + ) -> str | AsyncIterator[str]: """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. @@ -1196,12 +1196,7 @@ class LightRAG: return response async def _query_done(self): - tasks = [] - for storage_inst in [self.llm_response_cache]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await self.llm_response_cache.index_done_callback() def delete_by_entity(self, entity_name: str): loop = always_get_an_event_loop() @@ -1223,16 +1218,16 @@ class LightRAG: logger.error(f"Error while deleting entity '{entity_name}': {e}") async def _delete_by_entity_done(self): - tasks = [] - for storage_inst in [ - self.entities_vdb, - self.relationships_vdb, - self.chunk_entity_relation_graph, - ]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.entities_vdb, + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content @@ -1444,10 +1439,6 @@ class LightRAG: except Exception as e: logger.error(f"Error while deleting document {doc_id}: {e}") - def delete_by_doc_id(self, doc_id: str): - """Synchronous version of adelete""" - return asyncio.run(self.adelete_by_doc_id(doc_id)) - async def get_entity_info( self, entity_name: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: @@ -1484,21 +1475,6 @@ class LightRAG: return result - def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False): - """Synchronous version of getting entity information - - Args: - entity_name: Entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run(self.get_entity_info(entity_name, include_vector_data)) - finally: - tracemalloc.stop() - async def get_relation_info( self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): @@ -1540,23 +1516,3 @@ class LightRAG: result["vector_data"] = vector_data[0] if vector_data else None return result - - def get_relation_info_sync( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False - ): - """Synchronous version of getting relationship information - - Args: - src_entity: Source entity name (no need for quotes) - tgt_entity: Target entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run( - self.get_relation_info(src_entity, tgt_entity, include_vector_data) - ) - finally: - tracemalloc.stop()