Fix linting

yangdx
2025-04-14 12:08:56 +08:00
parent 40240bc79e
commit 5c1d4201f9
3 changed files with 136 additions and 108 deletions
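The diff below is formatting-only; no behavior changes. The recurring pattern is that multi-line argument lists gain a trailing comma after the last argument, and short calls that fit within the line-length limit are joined back onto one line, which matches a Black/Ruff-style formatter (the exact tool and configured line length are assumptions; the lint configuration is not part of this commit). A minimal runnable sketch of both patterns, using placeholder names rather than the real LightRAG APIs:

import asyncio


class _StubGraph:
    """Stand-in for the graph storage, only here to make the sketch executable."""

    async def remove_edge(self, source: str, target: str) -> None:
        print(f"removed edge {source} -> {target}")

    async def get_node_edges(self, name: str) -> list[tuple[str, str]]:
        return [(name, "other-entity")]


async def delete_relation_example(graph: _StubGraph, source: str, target: str):
    # Exploded call: each argument sits on its own line, and the last one also
    # ends with a comma (the "magic trailing comma" added throughout this commit).
    await graph.remove_edge(
        source,
        target,
    )

    # Short call: kept on one line because it fits the length limit, instead of
    # being split as
    #     edges = await graph.get_node_edges(
    #         source
    #     )
    edges = await graph.get_node_edges(source)
    return edges


if __name__ == "__main__":
    asyncio.run(delete_relation_example(_StubGraph(), "entity-A", "entity-B"))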

View File

@@ -1425,7 +1425,6 @@ class LightRAG:
async def _query_done(self):
await self.llm_response_cache.index_done_callback()
async def aclear_cache(self, modes: list[str] | None = None) -> None:
"""Clear cache data from the LLM response cache storage.
@@ -1479,7 +1478,6 @@ class LightRAG:
"""Synchronous version of aclear_cache."""
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
@@ -1748,7 +1746,6 @@ class LightRAG:
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
async def adelete_by_entity(self, entity_name: str) -> None:
"""Asynchronously delete an entity and all its relationships.
@@ -1756,11 +1753,12 @@ class LightRAG:
entity_name: Name of the entity to delete
"""
from .utils_graph import adelete_by_entity
return await adelete_by_entity(
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
entity_name
entity_name,
)
def delete_by_entity(self, entity_name: str) -> None:
@@ -1775,16 +1773,19 @@ class LightRAG:
target_entity: Name of the target entity
"""
from .utils_graph import adelete_by_relation
return await adelete_by_relation(
self.chunk_entity_relation_graph,
self.relationships_vdb,
source_entity,
target_entity
target_entity,
)
def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity))
return loop.run_until_complete(
self.adelete_by_relation(source_entity, target_entity)
)
async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts
@@ -1799,11 +1800,12 @@ class LightRAG:
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity"""
from .utils_graph import get_entity_info
return await get_entity_info(
self.chunk_entity_relation_graph,
self.entities_vdb,
entity_name,
include_vector_data
include_vector_data,
)
async def get_relation_info(
@@ -1811,12 +1813,13 @@ class LightRAG:
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship"""
from .utils_graph import get_relation_info
return await get_relation_info(
self.chunk_entity_relation_graph,
self.relationships_vdb,
src_entity,
tgt_entity,
include_vector_data
include_vector_data,
)
async def aedit_entity(
@@ -1835,13 +1838,14 @@ class LightRAG:
Dictionary containing updated entity information
"""
from .utils_graph import aedit_entity
return await aedit_entity(
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
entity_name,
updated_data,
allow_rename
allow_rename,
)
def edit_entity(
@@ -1868,13 +1872,14 @@ class LightRAG:
Dictionary containing updated relation information
"""
from .utils_graph import aedit_relation
return await aedit_relation(
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
source_entity,
target_entity,
updated_data
updated_data,
)
def edit_relation(
@@ -1900,12 +1905,13 @@ class LightRAG:
Dictionary containing created entity information
"""
from .utils_graph import acreate_entity
return await acreate_entity(
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
entity_name,
entity_data
entity_data,
)
def create_entity(
@@ -1930,13 +1936,14 @@ class LightRAG:
Dictionary containing created relation information
"""
from .utils_graph import acreate_relation
return await acreate_relation(
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
source_entity,
target_entity,
relation_data
relation_data,
)
def create_relation(
@@ -1975,6 +1982,7 @@ class LightRAG:
Dictionary containing the merged entity information
"""
from .utils_graph import amerge_entities
return await amerge_entities(
self.chunk_entity_relation_graph,
self.entities_vdb,
@@ -1982,7 +1990,7 @@ class LightRAG:
source_entities,
target_entity,
merge_strategy,
target_entity_data
target_entity_data,
)
def merge_entities(
@@ -2025,7 +2033,7 @@ class LightRAG:
self.relationships_vdb,
output_path,
file_format,
include_vector_data
include_vector_data,
)
def export_data(
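Several of the rewrapped calls above sit inside the thin synchronous wrappers (delete_by_relation, export_data) that delegate to their async counterparts through an event loop; once the awaited call is passed to run_until_complete it no longer fits on one line, hence the new indentation. A self-contained sketch of that delegation pattern, with a simplified stand-in for always_get_an_event_loop (an assumption about its behavior, not the real helper):

import asyncio


def _get_event_loop() -> asyncio.AbstractEventLoop:
    """Simplified stand-in for always_get_an_event_loop(): reuse the running
    loop if there is one, otherwise create and register a new one."""
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


class ExampleRAG:
    async def adelete_by_relation(
        self, source_entity: str, target_entity: str
    ) -> None:
        print(f"deleting relation {source_entity} -> {target_entity}")

    def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
        # The wrapped coroutine call is what the formatter splits across lines
        # once it exceeds the line-length limit.
        loop = _get_event_loop()
        return loop.run_until_complete(
            self.adelete_by_relation(source_entity, target_entity)
        )


if __name__ == "__main__":
    ExampleRAG().delete_by_relation("entity-A", "entity-B")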

View File

@@ -942,7 +942,9 @@ async def aexport_data(
entity_row = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": str(entity_info["graph_data"]), # Convert to string to ensure compatibility
"graph_data": str(
entity_info["graph_data"]
), # Convert to string to ensure compatibility
}
if include_vector_data and "vector_data" in entity_info:
entity_row["vector_data"] = str(entity_info["vector_data"])
@@ -1010,9 +1012,7 @@ async def aexport_data(
# Relations
if relations_data:
csvfile.write("# RELATIONS\n")
writer = csv.DictWriter(
csvfile, fieldnames=relations_data[0].keys()
)
writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
writer.writeheader()
writer.writerows(relations_data)
csvfile.write("\n\n")
@@ -1030,16 +1030,12 @@ async def aexport_data(
# Excel export
import pandas as pd
entities_df = (
pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
)
entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
relations_df = (
pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
)
relationships_df = (
pd.DataFrame(relationships_data)
if relationships_data
else pd.DataFrame()
pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
)
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
@@ -1063,9 +1059,7 @@ async def aexport_data(
# Write header
mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
mdfile.write(
"| "
+ " | ".join(["---"] * len(entities_data[0].keys()))
+ " |\n"
"| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
)
# Write rows
@@ -1083,17 +1077,13 @@ async def aexport_data(
# Write header
mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
mdfile.write(
"| "
+ " | ".join(["---"] * len(relations_data[0].keys()))
+ " |\n"
"| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
)
# Write rows
for relation in relations_data:
mdfile.write(
"| "
+ " | ".join(str(v) for v in relation.values())
+ " |\n"
"| " + " | ".join(str(v) for v in relation.values()) + " |\n"
)
mdfile.write("\n\n")
else:
@@ -1103,9 +1093,7 @@ async def aexport_data(
mdfile.write("## Relationships\n\n")
if relationships_data:
# Write header
mdfile.write(
"| " + " | ".join(relationships_data[0].keys()) + " |\n"
)
mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
mdfile.write(
"| "
+ " | ".join(["---"] * len(relationships_data[0].keys()))
@@ -1160,9 +1148,7 @@ async def aexport_data(
k: max(len(k), max(len(str(r[k])) for r in relations_data))
for k in relations_data[0]
}
header = " ".join(
k.ljust(col_widths[k]) for k in relations_data[0]
)
header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
txtfile.write(header + "\n")
txtfile.write("-" * len(header) + "\n")
@@ -1247,7 +1233,7 @@ def export_data(
relationships_vdb,
output_path,
file_format,
include_vector_data
include_vector_data,
)
)
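For context on what the rewrapped export hunks above are doing, here is a small self-contained sketch of the two patterns they touch: csv.DictWriter driven by the first row's keys for the CSV branch, and the "| col |" header plus "| --- |" separator pair for the Markdown branch. The sample rows are invented, not LightRAG output:

import csv
import io

# Invented sample rows; the real code builds these from the graph and vector stores.
relations_data = [
    {"src_id": "Alice", "tgt_id": "Bob", "description": "knows"},
    {"src_id": "Bob", "tgt_id": "Carol", "description": "works with"},
]

# CSV branch: one DictWriter per section, header row first, then all rows.
csvfile = io.StringIO()
csvfile.write("# RELATIONS\n")
writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
writer.writeheader()
writer.writerows(relations_data)
print(csvfile.getvalue())

# Markdown branch: header row, "---" separator row, then one row per record.
mdfile = io.StringIO()
mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
mdfile.write("| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n")
for relation in relations_data:
    mdfile.write("| " + " | ".join(str(v) for v in relation.values()) + " |\n")
print(mdfile.getvalue())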

View File

@@ -7,11 +7,9 @@ from .kg.shared_storage import get_graph_db_lock
from .prompt import GRAPH_FIELD_SEP
from .utils import compute_mdhash_id, logger, StorageNameSpace
async def adelete_by_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str
chunk_entity_relation_graph, entities_vdb, relationships_vdb, entity_name: str
) -> None:
"""Asynchronously delete an entity and all its relationships.
@@ -32,11 +30,16 @@ async def adelete_by_entity(
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
await _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity deletion is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -49,11 +52,12 @@ async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_r
]
)
async def adelete_by_relation(
chunk_entity_relation_graph,
relationships_vdb,
source_entity: str,
target_entity: str
target_entity: str,
) -> None:
"""Asynchronously delete a relation between two entities.
@@ -97,6 +101,7 @@ async def adelete_by_relation(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
)
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation deletion is complete, ensures updates are persisted"""
await asyncio.gather(
@@ -109,13 +114,14 @@ async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph)
]
)
async def aedit_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str,
updated_data: dict[str, str],
allow_rename: bool = True
allow_rename: bool = True,
) -> dict[str, Any]:
"""Asynchronously edit entity information.
@@ -183,9 +189,7 @@ async def aedit_entity(
relations_to_update = []
relations_to_delete = []
# Get all edges related to the original entity
edges = await chunk_entity_relation_graph.get_node_edges(
entity_name
)
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
if edges:
# Recreate edges for the new entity
for source, target in edges:
@@ -291,15 +295,25 @@ async def aedit_entity(
await entities_vdb.upsert(entity_data)
# 4. Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully updated")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error while editing entity '{entity_name}': {e}")
raise
async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity editing is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -312,13 +326,14 @@ async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relati
]
)
async def aedit_relation(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
source_entity: str,
target_entity: str,
updated_data: dict[str, Any]
updated_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously edit relation information.
@@ -402,7 +417,11 @@ async def aedit_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
)
return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(
@@ -410,6 +429,7 @@ async def aedit_relation(
)
raise
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation editing is complete, ensures updates are persisted"""
await asyncio.gather(
@@ -422,12 +442,13 @@ async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) ->
]
)
async def acreate_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str,
entity_data: dict[str, Any]
entity_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously create a new entity.
@@ -487,21 +508,29 @@ async def acreate_entity(
await entities_vdb.upsert(entity_data_for_vdb)
# Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully created")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error while creating entity '{entity_name}': {e}")
raise
async def acreate_relation(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
source_entity: str,
target_entity: str,
relation_data: dict[str, Any]
relation_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously create a new relation between entities.
@@ -523,12 +552,8 @@ async def acreate_relation(
async with graph_db_lock:
try:
# Check if both entities exist
source_exists = await chunk_entity_relation_graph.has_node(
source_entity
)
target_exists = await chunk_entity_relation_graph.has_node(
target_entity
)
source_exists = await chunk_entity_relation_graph.has_node(source_entity)
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
if not source_exists:
raise ValueError(f"Source entity '{source_entity}' does not exist")
@@ -594,7 +619,11 @@ async def acreate_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
)
return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(
@@ -602,6 +631,7 @@ async def acreate_relation(
)
raise
async def amerge_entities(
chunk_entity_relation_graph,
entities_vdb,
@@ -657,18 +687,14 @@ async def amerge_entities(
# 1. Check if all source entities exist
source_entities_data = {}
for entity_name in source_entities:
node_exists = await chunk_entity_relation_graph.has_node(
entity_name
)
node_exists = await chunk_entity_relation_graph.has_node(entity_name)
if not node_exists:
raise ValueError(f"Source entity '{entity_name}' does not exist")
node_data = await chunk_entity_relation_graph.get_node(entity_name)
source_entities_data[entity_name] = node_data
# 2. Check if target entity exists and get its data if it does
target_exists = await chunk_entity_relation_graph.has_node(
target_entity
)
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
existing_target_entity_data = {}
if target_exists:
existing_target_entity_data = (
@@ -693,9 +719,7 @@ async def amerge_entities(
all_relations = []
for entity_name in source_entities:
# Get all relationships of the source entities
edges = await chunk_entity_relation_graph.get_node_edges(
entity_name
)
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
if edges:
for src, tgt in edges:
# Ensure src is the current entity
@@ -842,17 +866,25 @@ async def amerge_entities(
)
# 10. Save changes
await _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
)
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, target_entity, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error merging entities: {e}")
raise
def _merge_entity_attributes(
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
@@ -902,6 +934,7 @@ def _merge_entity_attributes(
return merged_data
def _merge_relation_attributes(
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
@@ -925,9 +958,7 @@ def _merge_relation_attributes(
for key in all_keys:
# Get all values for this key
values = [
data.get(key)
for data in relation_data_list
if data.get(key) is not None
data.get(key) for data in relation_data_list if data.get(key) is not None
]
if not values:
@@ -961,7 +992,10 @@ def _merge_relation_attributes(
return merged_data
async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity merging is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -974,11 +1008,12 @@ async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_rel
]
)
async def get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name: str,
include_vector_data: bool = False
include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity"""
@@ -1000,19 +1035,18 @@ async def get_entity_info(
return result
async def get_relation_info(
chunk_entity_relation_graph,
relationships_vdb,
src_entity: str,
tgt_entity: str,
include_vector_data: bool = False
include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship"""
# Get information from the graph
edge_data = await chunk_entity_relation_graph.get_edge(
src_entity, tgt_entity
)
edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
source_id = edge_data.get("source_id") if edge_data else None
result: dict[str, str | None | dict[str, str]] = {