cleaned code

This commit is contained in:
Yannick Stephan
2025-02-15 00:01:21 +01:00
parent 7e526d3436
commit 805da7b95b
4 changed files with 21 additions and 59 deletions

View File

@@ -106,10 +106,16 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
@dataclass
class BaseKVStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
embedding_func: EmbeddingFunc | None = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
raise NotImplementedError

View File

@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
await self.delete([entity_id])
async def delete_entity_relation(self, entity_name: str):
async def delete_entity_relation(self, entity_name: str) -> None:
"""
Delete relations for a given entity by scanning metadata.
"""

View File

@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str):
async def delete_entity_relation(self, entity_name: str) -> None:
try:
relations = [
dp

View File

@@ -1095,7 +1095,7 @@ class LightRAG:
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1196,12 +1196,7 @@ class LightRAG:
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await self.llm_response_cache.index_done_callback()
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
@@ -1223,16 +1218,16 @@ class LightRAG:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
@@ -1444,10 +1439,6 @@ class LightRAG:
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
def delete_by_doc_id(self, doc_id: str):
"""Synchronous version of adelete"""
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
) -> dict[str, str | None | dict[str, str]]:
@@ -1484,21 +1475,6 @@ class LightRAG:
return result
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
"""Synchronous version of getting entity information
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
finally:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
@@ -1540,23 +1516,3 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
)
finally:
tracemalloc.stop()