Fix linting
This commit is contained in:
@@ -1425,7 +1425,6 @@ class LightRAG:
|
|||||||
async def _query_done(self):
|
async def _query_done(self):
|
||||||
await self.llm_response_cache.index_done_callback()
|
await self.llm_response_cache.index_done_callback()
|
||||||
|
|
||||||
|
|
||||||
async def aclear_cache(self, modes: list[str] | None = None) -> None:
|
async def aclear_cache(self, modes: list[str] | None = None) -> None:
|
||||||
"""Clear cache data from the LLM response cache storage.
|
"""Clear cache data from the LLM response cache storage.
|
||||||
|
|
||||||
@@ -1479,7 +1478,6 @@ class LightRAG:
|
|||||||
"""Synchronous version of aclear_cache."""
|
"""Synchronous version of aclear_cache."""
|
||||||
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
|
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
|
||||||
|
|
||||||
|
|
||||||
async def get_docs_by_status(
|
async def get_docs_by_status(
|
||||||
self, status: DocStatus
|
self, status: DocStatus
|
||||||
) -> dict[str, DocProcessingStatus]:
|
) -> dict[str, DocProcessingStatus]:
|
||||||
@@ -1748,7 +1746,6 @@ class LightRAG:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while deleting document {doc_id}: {e}")
|
logger.error(f"Error while deleting document {doc_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def adelete_by_entity(self, entity_name: str) -> None:
|
async def adelete_by_entity(self, entity_name: str) -> None:
|
||||||
"""Asynchronously delete an entity and all its relationships.
|
"""Asynchronously delete an entity and all its relationships.
|
||||||
|
|
||||||
@@ -1756,11 +1753,12 @@ class LightRAG:
|
|||||||
entity_name: Name of the entity to delete
|
entity_name: Name of the entity to delete
|
||||||
"""
|
"""
|
||||||
from .utils_graph import adelete_by_entity
|
from .utils_graph import adelete_by_entity
|
||||||
|
|
||||||
return await adelete_by_entity(
|
return await adelete_by_entity(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
entity_name
|
entity_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_by_entity(self, entity_name: str) -> None:
|
def delete_by_entity(self, entity_name: str) -> None:
|
||||||
@@ -1775,16 +1773,19 @@ class LightRAG:
|
|||||||
target_entity: Name of the target entity
|
target_entity: Name of the target entity
|
||||||
"""
|
"""
|
||||||
from .utils_graph import adelete_by_relation
|
from .utils_graph import adelete_by_relation
|
||||||
|
|
||||||
return await adelete_by_relation(
|
return await adelete_by_relation(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
source_entity,
|
source_entity,
|
||||||
target_entity
|
target_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
|
def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity))
|
return loop.run_until_complete(
|
||||||
|
self.adelete_by_relation(source_entity, target_entity)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_processing_status(self) -> dict[str, int]:
|
async def get_processing_status(self) -> dict[str, int]:
|
||||||
"""Get current document processing status counts
|
"""Get current document processing status counts
|
||||||
@@ -1799,11 +1800,12 @@ class LightRAG:
|
|||||||
) -> dict[str, str | None | dict[str, str]]:
|
) -> dict[str, str | None | dict[str, str]]:
|
||||||
"""Get detailed information of an entity"""
|
"""Get detailed information of an entity"""
|
||||||
from .utils_graph import get_entity_info
|
from .utils_graph import get_entity_info
|
||||||
|
|
||||||
return await get_entity_info(
|
return await get_entity_info(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
entity_name,
|
entity_name,
|
||||||
include_vector_data
|
include_vector_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_relation_info(
|
async def get_relation_info(
|
||||||
@@ -1811,12 +1813,13 @@ class LightRAG:
|
|||||||
) -> dict[str, str | None | dict[str, str]]:
|
) -> dict[str, str | None | dict[str, str]]:
|
||||||
"""Get detailed information of a relationship"""
|
"""Get detailed information of a relationship"""
|
||||||
from .utils_graph import get_relation_info
|
from .utils_graph import get_relation_info
|
||||||
|
|
||||||
return await get_relation_info(
|
return await get_relation_info(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
src_entity,
|
src_entity,
|
||||||
tgt_entity,
|
tgt_entity,
|
||||||
include_vector_data
|
include_vector_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def aedit_entity(
|
async def aedit_entity(
|
||||||
@@ -1835,13 +1838,14 @@ class LightRAG:
|
|||||||
Dictionary containing updated entity information
|
Dictionary containing updated entity information
|
||||||
"""
|
"""
|
||||||
from .utils_graph import aedit_entity
|
from .utils_graph import aedit_entity
|
||||||
|
|
||||||
return await aedit_entity(
|
return await aedit_entity(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
entity_name,
|
entity_name,
|
||||||
updated_data,
|
updated_data,
|
||||||
allow_rename
|
allow_rename,
|
||||||
)
|
)
|
||||||
|
|
||||||
def edit_entity(
|
def edit_entity(
|
||||||
@@ -1868,13 +1872,14 @@ class LightRAG:
|
|||||||
Dictionary containing updated relation information
|
Dictionary containing updated relation information
|
||||||
"""
|
"""
|
||||||
from .utils_graph import aedit_relation
|
from .utils_graph import aedit_relation
|
||||||
|
|
||||||
return await aedit_relation(
|
return await aedit_relation(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
source_entity,
|
source_entity,
|
||||||
target_entity,
|
target_entity,
|
||||||
updated_data
|
updated_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def edit_relation(
|
def edit_relation(
|
||||||
@@ -1900,12 +1905,13 @@ class LightRAG:
|
|||||||
Dictionary containing created entity information
|
Dictionary containing created entity information
|
||||||
"""
|
"""
|
||||||
from .utils_graph import acreate_entity
|
from .utils_graph import acreate_entity
|
||||||
|
|
||||||
return await acreate_entity(
|
return await acreate_entity(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
entity_name,
|
entity_name,
|
||||||
entity_data
|
entity_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_entity(
|
def create_entity(
|
||||||
@@ -1930,13 +1936,14 @@ class LightRAG:
|
|||||||
Dictionary containing created relation information
|
Dictionary containing created relation information
|
||||||
"""
|
"""
|
||||||
from .utils_graph import acreate_relation
|
from .utils_graph import acreate_relation
|
||||||
|
|
||||||
return await acreate_relation(
|
return await acreate_relation(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
source_entity,
|
source_entity,
|
||||||
target_entity,
|
target_entity,
|
||||||
relation_data
|
relation_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_relation(
|
def create_relation(
|
||||||
@@ -1975,6 +1982,7 @@ class LightRAG:
|
|||||||
Dictionary containing the merged entity information
|
Dictionary containing the merged entity information
|
||||||
"""
|
"""
|
||||||
from .utils_graph import amerge_entities
|
from .utils_graph import amerge_entities
|
||||||
|
|
||||||
return await amerge_entities(
|
return await amerge_entities(
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
@@ -1982,7 +1990,7 @@ class LightRAG:
|
|||||||
source_entities,
|
source_entities,
|
||||||
target_entity,
|
target_entity,
|
||||||
merge_strategy,
|
merge_strategy,
|
||||||
target_entity_data
|
target_entity_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge_entities(
|
def merge_entities(
|
||||||
@@ -2025,7 +2033,7 @@ class LightRAG:
|
|||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
output_path,
|
output_path,
|
||||||
file_format,
|
file_format,
|
||||||
include_vector_data
|
include_vector_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def export_data(
|
def export_data(
|
||||||
|
@@ -942,7 +942,9 @@ async def aexport_data(
|
|||||||
entity_row = {
|
entity_row = {
|
||||||
"entity_name": entity_name,
|
"entity_name": entity_name,
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
"graph_data": str(entity_info["graph_data"]), # Convert to string to ensure compatibility
|
"graph_data": str(
|
||||||
|
entity_info["graph_data"]
|
||||||
|
), # Convert to string to ensure compatibility
|
||||||
}
|
}
|
||||||
if include_vector_data and "vector_data" in entity_info:
|
if include_vector_data and "vector_data" in entity_info:
|
||||||
entity_row["vector_data"] = str(entity_info["vector_data"])
|
entity_row["vector_data"] = str(entity_info["vector_data"])
|
||||||
@@ -1010,9 +1012,7 @@ async def aexport_data(
|
|||||||
# Relations
|
# Relations
|
||||||
if relations_data:
|
if relations_data:
|
||||||
csvfile.write("# RELATIONS\n")
|
csvfile.write("# RELATIONS\n")
|
||||||
writer = csv.DictWriter(
|
writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
|
||||||
csvfile, fieldnames=relations_data[0].keys()
|
|
||||||
)
|
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
writer.writerows(relations_data)
|
writer.writerows(relations_data)
|
||||||
csvfile.write("\n\n")
|
csvfile.write("\n\n")
|
||||||
@@ -1030,16 +1030,12 @@ async def aexport_data(
|
|||||||
# Excel export
|
# Excel export
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
entities_df = (
|
entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
|
||||||
pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
|
|
||||||
)
|
|
||||||
relations_df = (
|
relations_df = (
|
||||||
pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
|
pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
|
||||||
)
|
)
|
||||||
relationships_df = (
|
relationships_df = (
|
||||||
pd.DataFrame(relationships_data)
|
pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
|
||||||
if relationships_data
|
|
||||||
else pd.DataFrame()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
|
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
|
||||||
@@ -1063,9 +1059,7 @@ async def aexport_data(
|
|||||||
# Write header
|
# Write header
|
||||||
mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
|
mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
|
||||||
mdfile.write(
|
mdfile.write(
|
||||||
"| "
|
"| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
|
||||||
+ " | ".join(["---"] * len(entities_data[0].keys()))
|
|
||||||
+ " |\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Write rows
|
# Write rows
|
||||||
@@ -1083,17 +1077,13 @@ async def aexport_data(
|
|||||||
# Write header
|
# Write header
|
||||||
mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
|
mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
|
||||||
mdfile.write(
|
mdfile.write(
|
||||||
"| "
|
"| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
|
||||||
+ " | ".join(["---"] * len(relations_data[0].keys()))
|
|
||||||
+ " |\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Write rows
|
# Write rows
|
||||||
for relation in relations_data:
|
for relation in relations_data:
|
||||||
mdfile.write(
|
mdfile.write(
|
||||||
"| "
|
"| " + " | ".join(str(v) for v in relation.values()) + " |\n"
|
||||||
+ " | ".join(str(v) for v in relation.values())
|
|
||||||
+ " |\n"
|
|
||||||
)
|
)
|
||||||
mdfile.write("\n\n")
|
mdfile.write("\n\n")
|
||||||
else:
|
else:
|
||||||
@@ -1103,9 +1093,7 @@ async def aexport_data(
|
|||||||
mdfile.write("## Relationships\n\n")
|
mdfile.write("## Relationships\n\n")
|
||||||
if relationships_data:
|
if relationships_data:
|
||||||
# Write header
|
# Write header
|
||||||
mdfile.write(
|
mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
|
||||||
"| " + " | ".join(relationships_data[0].keys()) + " |\n"
|
|
||||||
)
|
|
||||||
mdfile.write(
|
mdfile.write(
|
||||||
"| "
|
"| "
|
||||||
+ " | ".join(["---"] * len(relationships_data[0].keys()))
|
+ " | ".join(["---"] * len(relationships_data[0].keys()))
|
||||||
@@ -1160,9 +1148,7 @@ async def aexport_data(
|
|||||||
k: max(len(k), max(len(str(r[k])) for r in relations_data))
|
k: max(len(k), max(len(str(r[k])) for r in relations_data))
|
||||||
for k in relations_data[0]
|
for k in relations_data[0]
|
||||||
}
|
}
|
||||||
header = " ".join(
|
header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
|
||||||
k.ljust(col_widths[k]) for k in relations_data[0]
|
|
||||||
)
|
|
||||||
txtfile.write(header + "\n")
|
txtfile.write(header + "\n")
|
||||||
txtfile.write("-" * len(header) + "\n")
|
txtfile.write("-" * len(header) + "\n")
|
||||||
|
|
||||||
@@ -1247,7 +1233,7 @@ def export_data(
|
|||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
output_path,
|
output_path,
|
||||||
file_format,
|
file_format,
|
||||||
include_vector_data
|
include_vector_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -7,11 +7,9 @@ from .kg.shared_storage import get_graph_db_lock
|
|||||||
from .prompt import GRAPH_FIELD_SEP
|
from .prompt import GRAPH_FIELD_SEP
|
||||||
from .utils import compute_mdhash_id, logger, StorageNameSpace
|
from .utils import compute_mdhash_id, logger, StorageNameSpace
|
||||||
|
|
||||||
|
|
||||||
async def adelete_by_entity(
|
async def adelete_by_entity(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph, entities_vdb, relationships_vdb, entity_name: str
|
||||||
entities_vdb,
|
|
||||||
relationships_vdb,
|
|
||||||
entity_name: str
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Asynchronously delete an entity and all its relationships.
|
"""Asynchronously delete an entity and all its relationships.
|
||||||
|
|
||||||
@@ -32,11 +30,16 @@ async def adelete_by_entity(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Entity '{entity_name}' and its relationships have been deleted."
|
f"Entity '{entity_name}' and its relationships have been deleted."
|
||||||
)
|
)
|
||||||
await _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
await _delete_by_entity_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while deleting entity '{entity_name}': {e}")
|
logger.error(f"Error while deleting entity '{entity_name}': {e}")
|
||||||
|
|
||||||
async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
|
||||||
|
async def _delete_by_entity_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
) -> None:
|
||||||
"""Callback after entity deletion is complete, ensures updates are persisted"""
|
"""Callback after entity deletion is complete, ensures updates are persisted"""
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
@@ -49,11 +52,12 @@ async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_r
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def adelete_by_relation(
|
async def adelete_by_relation(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
source_entity: str,
|
source_entity: str,
|
||||||
target_entity: str
|
target_entity: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Asynchronously delete a relation between two entities.
|
"""Asynchronously delete a relation between two entities.
|
||||||
|
|
||||||
@@ -97,6 +101,7 @@ async def adelete_by_relation(
|
|||||||
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
|
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||||
"""Callback after relation deletion is complete, ensures updates are persisted"""
|
"""Callback after relation deletion is complete, ensures updates are persisted"""
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
@@ -109,13 +114,14 @@ async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph)
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def aedit_entity(
|
async def aedit_entity(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
entity_name: str,
|
entity_name: str,
|
||||||
updated_data: dict[str, str],
|
updated_data: dict[str, str],
|
||||||
allow_rename: bool = True
|
allow_rename: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Asynchronously edit entity information.
|
"""Asynchronously edit entity information.
|
||||||
|
|
||||||
@@ -183,9 +189,7 @@ async def aedit_entity(
|
|||||||
relations_to_update = []
|
relations_to_update = []
|
||||||
relations_to_delete = []
|
relations_to_delete = []
|
||||||
# Get all edges related to the original entity
|
# Get all edges related to the original entity
|
||||||
edges = await chunk_entity_relation_graph.get_node_edges(
|
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
|
||||||
entity_name
|
|
||||||
)
|
|
||||||
if edges:
|
if edges:
|
||||||
# Recreate edges for the new entity
|
# Recreate edges for the new entity
|
||||||
for source, target in edges:
|
for source, target in edges:
|
||||||
@@ -291,15 +295,25 @@ async def aedit_entity(
|
|||||||
await entities_vdb.upsert(entity_data)
|
await entities_vdb.upsert(entity_data)
|
||||||
|
|
||||||
# 4. Save changes
|
# 4. Save changes
|
||||||
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
await _edit_entity_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Entity '{entity_name}' successfully updated")
|
logger.info(f"Entity '{entity_name}' successfully updated")
|
||||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
|
return await get_entity_info(
|
||||||
|
chunk_entity_relation_graph,
|
||||||
|
entities_vdb,
|
||||||
|
entity_name,
|
||||||
|
include_vector_data=True,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while editing entity '{entity_name}': {e}")
|
logger.error(f"Error while editing entity '{entity_name}': {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
|
||||||
|
async def _edit_entity_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
) -> None:
|
||||||
"""Callback after entity editing is complete, ensures updates are persisted"""
|
"""Callback after entity editing is complete, ensures updates are persisted"""
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
@@ -312,13 +326,14 @@ async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relati
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def aedit_relation(
|
async def aedit_relation(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
source_entity: str,
|
source_entity: str,
|
||||||
target_entity: str,
|
target_entity: str,
|
||||||
updated_data: dict[str, Any]
|
updated_data: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Asynchronously edit relation information.
|
"""Asynchronously edit relation information.
|
||||||
|
|
||||||
@@ -402,7 +417,11 @@ async def aedit_relation(
|
|||||||
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
|
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
|
||||||
)
|
)
|
||||||
return await get_relation_info(
|
return await get_relation_info(
|
||||||
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
|
chunk_entity_relation_graph,
|
||||||
|
relationships_vdb,
|
||||||
|
source_entity,
|
||||||
|
target_entity,
|
||||||
|
include_vector_data=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -410,6 +429,7 @@ async def aedit_relation(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||||
"""Callback after relation editing is complete, ensures updates are persisted"""
|
"""Callback after relation editing is complete, ensures updates are persisted"""
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
@@ -422,12 +442,13 @@ async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) ->
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def acreate_entity(
|
async def acreate_entity(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
entity_name: str,
|
entity_name: str,
|
||||||
entity_data: dict[str, Any]
|
entity_data: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Asynchronously create a new entity.
|
"""Asynchronously create a new entity.
|
||||||
|
|
||||||
@@ -487,21 +508,29 @@ async def acreate_entity(
|
|||||||
await entities_vdb.upsert(entity_data_for_vdb)
|
await entities_vdb.upsert(entity_data_for_vdb)
|
||||||
|
|
||||||
# Save changes
|
# Save changes
|
||||||
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
await _edit_entity_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Entity '{entity_name}' successfully created")
|
logger.info(f"Entity '{entity_name}' successfully created")
|
||||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
|
return await get_entity_info(
|
||||||
|
chunk_entity_relation_graph,
|
||||||
|
entities_vdb,
|
||||||
|
entity_name,
|
||||||
|
include_vector_data=True,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while creating entity '{entity_name}': {e}")
|
logger.error(f"Error while creating entity '{entity_name}': {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def acreate_relation(
|
async def acreate_relation(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
source_entity: str,
|
source_entity: str,
|
||||||
target_entity: str,
|
target_entity: str,
|
||||||
relation_data: dict[str, Any]
|
relation_data: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Asynchronously create a new relation between entities.
|
"""Asynchronously create a new relation between entities.
|
||||||
|
|
||||||
@@ -523,12 +552,8 @@ async def acreate_relation(
|
|||||||
async with graph_db_lock:
|
async with graph_db_lock:
|
||||||
try:
|
try:
|
||||||
# Check if both entities exist
|
# Check if both entities exist
|
||||||
source_exists = await chunk_entity_relation_graph.has_node(
|
source_exists = await chunk_entity_relation_graph.has_node(source_entity)
|
||||||
source_entity
|
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
|
||||||
)
|
|
||||||
target_exists = await chunk_entity_relation_graph.has_node(
|
|
||||||
target_entity
|
|
||||||
)
|
|
||||||
|
|
||||||
if not source_exists:
|
if not source_exists:
|
||||||
raise ValueError(f"Source entity '{source_entity}' does not exist")
|
raise ValueError(f"Source entity '{source_entity}' does not exist")
|
||||||
@@ -594,7 +619,11 @@ async def acreate_relation(
|
|||||||
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
|
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
|
||||||
)
|
)
|
||||||
return await get_relation_info(
|
return await get_relation_info(
|
||||||
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
|
chunk_entity_relation_graph,
|
||||||
|
relationships_vdb,
|
||||||
|
source_entity,
|
||||||
|
target_entity,
|
||||||
|
include_vector_data=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -602,6 +631,7 @@ async def acreate_relation(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def amerge_entities(
|
async def amerge_entities(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
@@ -657,18 +687,14 @@ async def amerge_entities(
|
|||||||
# 1. Check if all source entities exist
|
# 1. Check if all source entities exist
|
||||||
source_entities_data = {}
|
source_entities_data = {}
|
||||||
for entity_name in source_entities:
|
for entity_name in source_entities:
|
||||||
node_exists = await chunk_entity_relation_graph.has_node(
|
node_exists = await chunk_entity_relation_graph.has_node(entity_name)
|
||||||
entity_name
|
|
||||||
)
|
|
||||||
if not node_exists:
|
if not node_exists:
|
||||||
raise ValueError(f"Source entity '{entity_name}' does not exist")
|
raise ValueError(f"Source entity '{entity_name}' does not exist")
|
||||||
node_data = await chunk_entity_relation_graph.get_node(entity_name)
|
node_data = await chunk_entity_relation_graph.get_node(entity_name)
|
||||||
source_entities_data[entity_name] = node_data
|
source_entities_data[entity_name] = node_data
|
||||||
|
|
||||||
# 2. Check if target entity exists and get its data if it does
|
# 2. Check if target entity exists and get its data if it does
|
||||||
target_exists = await chunk_entity_relation_graph.has_node(
|
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
|
||||||
target_entity
|
|
||||||
)
|
|
||||||
existing_target_entity_data = {}
|
existing_target_entity_data = {}
|
||||||
if target_exists:
|
if target_exists:
|
||||||
existing_target_entity_data = (
|
existing_target_entity_data = (
|
||||||
@@ -693,9 +719,7 @@ async def amerge_entities(
|
|||||||
all_relations = []
|
all_relations = []
|
||||||
for entity_name in source_entities:
|
for entity_name in source_entities:
|
||||||
# Get all relationships of the source entities
|
# Get all relationships of the source entities
|
||||||
edges = await chunk_entity_relation_graph.get_node_edges(
|
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
|
||||||
entity_name
|
|
||||||
)
|
|
||||||
if edges:
|
if edges:
|
||||||
for src, tgt in edges:
|
for src, tgt in edges:
|
||||||
# Ensure src is the current entity
|
# Ensure src is the current entity
|
||||||
@@ -842,17 +866,25 @@ async def amerge_entities(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 10. Save changes
|
# 10. Save changes
|
||||||
await _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
await _merge_entities_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
|
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
|
||||||
)
|
)
|
||||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, target_entity, include_vector_data=True)
|
return await get_entity_info(
|
||||||
|
chunk_entity_relation_graph,
|
||||||
|
entities_vdb,
|
||||||
|
target_entity,
|
||||||
|
include_vector_data=True,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error merging entities: {e}")
|
logger.error(f"Error merging entities: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _merge_entity_attributes(
|
def _merge_entity_attributes(
|
||||||
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@@ -902,6 +934,7 @@ def _merge_entity_attributes(
|
|||||||
|
|
||||||
return merged_data
|
return merged_data
|
||||||
|
|
||||||
|
|
||||||
def _merge_relation_attributes(
|
def _merge_relation_attributes(
|
||||||
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@@ -925,9 +958,7 @@ def _merge_relation_attributes(
|
|||||||
for key in all_keys:
|
for key in all_keys:
|
||||||
# Get all values for this key
|
# Get all values for this key
|
||||||
values = [
|
values = [
|
||||||
data.get(key)
|
data.get(key) for data in relation_data_list if data.get(key) is not None
|
||||||
for data in relation_data_list
|
|
||||||
if data.get(key) is not None
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if not values:
|
if not values:
|
||||||
@@ -961,7 +992,10 @@ def _merge_relation_attributes(
|
|||||||
|
|
||||||
return merged_data
|
return merged_data
|
||||||
|
|
||||||
async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
|
||||||
|
async def _merge_entities_done(
|
||||||
|
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||||
|
) -> None:
|
||||||
"""Callback after entity merging is complete, ensures updates are persisted"""
|
"""Callback after entity merging is complete, ensures updates are persisted"""
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
@@ -974,11 +1008,12 @@ async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_rel
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_entity_info(
|
async def get_entity_info(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
entities_vdb,
|
entities_vdb,
|
||||||
entity_name: str,
|
entity_name: str,
|
||||||
include_vector_data: bool = False
|
include_vector_data: bool = False,
|
||||||
) -> dict[str, str | None | dict[str, str]]:
|
) -> dict[str, str | None | dict[str, str]]:
|
||||||
"""Get detailed information of an entity"""
|
"""Get detailed information of an entity"""
|
||||||
|
|
||||||
@@ -1000,19 +1035,18 @@ async def get_entity_info(
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def get_relation_info(
|
async def get_relation_info(
|
||||||
chunk_entity_relation_graph,
|
chunk_entity_relation_graph,
|
||||||
relationships_vdb,
|
relationships_vdb,
|
||||||
src_entity: str,
|
src_entity: str,
|
||||||
tgt_entity: str,
|
tgt_entity: str,
|
||||||
include_vector_data: bool = False
|
include_vector_data: bool = False,
|
||||||
) -> dict[str, str | None | dict[str, str]]:
|
) -> dict[str, str | None | dict[str, str]]:
|
||||||
"""Get detailed information of a relationship"""
|
"""Get detailed information of a relationship"""
|
||||||
|
|
||||||
# Get information from the graph
|
# Get information from the graph
|
||||||
edge_data = await chunk_entity_relation_graph.get_edge(
|
edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
|
||||||
src_entity, tgt_entity
|
|
||||||
)
|
|
||||||
source_id = edge_data.get("source_id") if edge_data else None
|
source_id = edge_data.get("source_id") if edge_data else None
|
||||||
|
|
||||||
result: dict[str, str | None | dict[str, str]] = {
|
result: dict[str, str | None | dict[str, str]] = {
|
||||||
|
Reference in New Issue
Block a user