Merge branch 'main' into validate-content-before-enqueue
@@ -16,12 +16,32 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
     @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
     async def get_graph_labels():
-        """Get all graph labels"""
+        """
+        Get all graph labels
+
+        Returns:
+            List[str]: List of graph labels
+        """
         return await rag.get_graph_labels()
 
     @router.get("/graphs", dependencies=[Depends(optional_api_key)])
     async def get_knowledge_graph(label: str, max_depth: int = 3):
-        """Get knowledge graph for a specific label"""
+        """
+        Retrieve a connected subgraph of nodes whose labels include the specified label.
+        The maximum number of nodes is constrained by the environment variable
+        `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes,
+        the prioritization criteria are as follows:
+            1. Label-matching nodes take precedence
+            2. Followed by nodes directly connected to the matching nodes
+            3. Finally, nodes are ranked by degree
+
+        Args:
+            label (str): Label to get the knowledge graph for
+            max_depth (int, optional): Maximum depth of the graph. Defaults to 3.
+
+        Returns:
+            Dict[str, List[str]]: Knowledge graph for the label
+        """
         return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
 
     return router
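A quick way to exercise the two endpoints above is an HTTP client pointed at a running LightRAG API server. This is a minimal sketch only: the base URL, port, and `X-API-Key` header name are assumptions for illustration, not taken from this diff.

# Minimal sketch: query the graph endpoints of a running LightRAG API server.
import requests

BASE = "http://localhost:9621"          # assumed server address
HEADERS = {"X-API-Key": "your-key"}     # only needed if an API key is configured

# List all labels known to the graph storage
labels = requests.get(f"{BASE}/graph/label/list", headers=HEADERS).json()
print(labels)

# Fetch the subgraph around one label; node count is capped by MAX_GRAPH_NODES
graph = requests.get(
    f"{BASE}/graphs",
    params={"label": labels[0] if labels else "*", "max_depth": 3},
    headers=HEADERS,
).json()
print(len(graph.get("nodes", [])), "nodes")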
@@ -23,7 +23,7 @@ import pipmaster as pm
 if not pm.is_installed("neo4j"):
     pm.install("neo4j")
 
-from neo4j import (
+from neo4j import (  # type: ignore
     AsyncGraphDatabase,
     exceptions as neo4jExceptions,
     AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
+# Get maximum number of graph nodes from environment variable, default is 1000
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
 
 @final
 @dataclass
@@ -470,40 +473,61 @@ class Neo4JStorage(BaseGraphStorage):
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
         """
-        Get complete connected subgraph for specified node (including the starting node itself)
+        Retrieve a connected subgraph of nodes whose labels include the specified `node_label`.
+        The maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
+        When reducing the number of nodes, the prioritization criteria are as follows:
+            1. Label-matching nodes take precedence (nodes containing the specified label string)
+            2. Followed by nodes directly connected to the matching nodes
+            3. Finally, nodes are ranked by degree
 
-        Key fixes:
-        1. Include the starting node itself
-        2. Handle multi-label nodes
-        3. Clarify relationship directions
-        4. Add depth control
         Args:
             node_label (str): String to match in node labels (will match any node containing this string in its label)
             max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
         Returns:
             KnowledgeGraph: Complete connected subgraph for the specified node
         """
         label = node_label.strip('"')
+        # Escape single quotes to prevent injection attacks
+        escaped_label = label.replace("'", "\\'")
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
 
         async with self._driver.session(database=self._DATABASE) as session:
             try:
+                main_query = ""
                 if label == "*":
                     main_query = """
                     MATCH (n)
-                    WITH collect(DISTINCT n) AS nodes
-                    MATCH ()-[r]-()
-                    RETURN nodes, collect(DISTINCT r) AS relationships;
+                    OPTIONAL MATCH (n)-[r]-()
+                    WITH n, count(r) AS degree
+                    ORDER BY degree DESC
+                    LIMIT $max_nodes
+                    WITH collect(n) AS nodes
+                    MATCH (a)-[r]->(b)
+                    WHERE a IN nodes AND b IN nodes
+                    RETURN nodes, collect(DISTINCT r) AS relationships
                     """
+                    result_set = await session.run(
+                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                    )
 
                 else:
                     # Critical debug step: first verify if starting node exists
-                    validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
+                    validate_query = f"""
+                    MATCH (n)
+                    WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
+                    RETURN n LIMIT 1
+                    """
                     validate_result = await session.run(validate_query)
                     if not await validate_result.single():
-                        logger.warning(f"Starting node {label} does not exist!")
+                        logger.warning(
+                            f"No nodes containing '{label}' in their labels found!"
+                        )
                         return result
 
-                    # Optimized query (including direction handling and self-loops)
+                    # Main query uses partial matching
                     main_query = f"""
-                    MATCH (start:`{label}`)
+                    MATCH (start)
+                    WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
+                    WITH start
                     CALL apoc.path.subgraphAll(start, {{
                         relationshipFilter: '>',
@@ -512,9 +536,25 @@ class Neo4JStorage(BaseGraphStorage):
                         bfs: true
                     }})
                     YIELD nodes, relationships
-                    RETURN nodes, relationships
+                    WITH start, nodes, relationships
+                    UNWIND nodes AS node
+                    OPTIONAL MATCH (node)-[r]-()
+                    WITH node, count(r) AS degree, start, nodes, relationships,
+                        CASE
+                            WHEN id(node) = id(start) THEN 2
+                            WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
+                            ELSE 0
+                        END AS priority
+                    ORDER BY priority DESC, degree DESC
+                    LIMIT $max_nodes
+                    WITH collect(node) AS filtered_nodes, nodes, relationships
+                    RETURN filtered_nodes AS nodes,
+                        [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
                     """
-                    result_set = await session.run(main_query)
+                    result_set = await session.run(
+                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                    )
 
                 record = await result_set.single()
 
                 if record:
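The Cypher `CASE` above ranks every subgraph node before `LIMIT` is applied. The same ranking, restated as plain Python over (node, degree) pairs, may make the ordering easier to verify. This is an illustrative sketch, not code from the diff; `start_ids` and `neighbor_ids` are hypothetical stand-ins for the `id(start)` and `EXISTS((start)-->(node))` checks.

# Illustrative re-statement of the Cypher priority ranking in plain Python.
MAX_GRAPH_NODES = 1000

def truncate(nodes_with_degree, start_ids, neighbor_ids):
    """nodes_with_degree: list of (node_id, degree) tuples."""
    def priority(item):
        node_id, degree = item
        if node_id in start_ids:       # the matched start nodes themselves
            p = 2
        elif node_id in neighbor_ids:  # directly connected to a start node
            p = 1
        else:
            p = 0
        return (p, degree)             # ORDER BY priority DESC, degree DESC

    return sorted(nodes_with_degree, key=priority, reverse=True)[:MAX_GRAPH_NODES]

print(truncate([(1, 5), (2, 9), (3, 2)], start_ids={3}, neighbor_ids={1}))
# -> [(3, 2), (1, 5), (2, 9)]: priority wins over raw degree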
@@ -24,6 +24,8 @@ from .shared_storage import (
     is_multiprocess,
 )
 
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
 
 @final
 @dataclass
@@ -233,7 +235,12 @@ class NetworkXStorage(BaseGraphStorage):
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
         """
-        Get complete connected subgraph for specified node (including the starting node itself)
+        Retrieve a connected subgraph of nodes whose labels include the specified `node_label`.
+        The maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
+        When reducing the number of nodes, the prioritization criteria are as follows:
+            1. Label-matching nodes take precedence
+            2. Followed by nodes directly connected to the matching nodes
+            3. Finally, nodes are ranked by degree
 
         Args:
             node_label: Label of the starting node
@@ -265,22 +272,51 @@ class NetworkXStorage(BaseGraphStorage):
             logger.warning(f"No nodes found with label {node_label}")
             return result
 
-        # Get subgraph using ego_graph
-        subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
+        # Get subgraph using ego_graph from all matching nodes
+        combined_subgraph = nx.Graph()
+        for start_node in nodes_to_explore:
+            node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
+            combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
+        subgraph = combined_subgraph
 
         # Check if number of nodes exceeds max_graph_nodes
-        max_graph_nodes = 500
-        if len(subgraph.nodes()) > max_graph_nodes:
+        if len(subgraph.nodes()) > MAX_GRAPH_NODES:
             origin_nodes = len(subgraph.nodes())
 
             node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
-                :max_graph_nodes
+
+            start_nodes = set()
+            direct_connected_nodes = set()
+
+            if node_label != "*" and nodes_to_explore:
+                start_nodes = set(nodes_to_explore)
+                # Get nodes directly connected to all start nodes
+                for start_node in start_nodes:
+                    direct_connected_nodes.update(subgraph.neighbors(start_node))
+
+                # Remove start nodes from directly connected nodes (avoid duplicates)
+                direct_connected_nodes -= start_nodes
+
+            def priority_key(node_item):
+                node, degree = node_item
+                # Priority order: start (2) > directly connected (1) > other nodes (0)
+                if node in start_nodes:
+                    priority = 2
+                elif node in direct_connected_nodes:
+                    priority = 1
+                else:
+                    priority = 0
+                return (priority, degree)
+
+            # Sort by priority and degree, then keep the top MAX_GRAPH_NODES nodes
+            top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
+                :MAX_GRAPH_NODES
             ]
             top_node_ids = [node[0] for node in top_nodes]
-            # Create new subgraph with only top nodes
+            # Create new subgraph keeping only the selected nodes
             subgraph = subgraph.subgraph(top_node_ids)
             logger.info(
-                f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
+                f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
            )
 
         # Add nodes to result
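The multi-start ego-graph merge above is easy to check on a toy graph. A minimal sketch, assuming only that `networkx` is installed; the graph and node names are invented for illustration.

# Toy demonstration of the multi-start ego_graph merge used above.
import networkx as nx

graph = nx.Graph()
graph.add_edges_from([("A", "B"), ("B", "C"), ("C", "D"), ("X", "Y")])

# Two matching start nodes, as when several labels contain the search string
nodes_to_explore = ["A", "X"]

combined = nx.Graph()
for start in nodes_to_explore:
    # Each ego graph contains every node within `radius` hops of the start node
    combined = nx.compose(combined, nx.ego_graph(graph, start, radius=1))

print(sorted(combined.nodes()))  # ['A', 'B', 'X', 'Y'] -- both neighborhoods kept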
@@ -320,7 +356,7 @@ class NetworkXStorage(BaseGraphStorage):
                 result.edges.append(
                     KnowledgeGraphEdge(
                         id=edge_id,
-                        type="DIRECTED",
+                        type="RELATED",
                         source=str(source),
                         target=str(target),
                         properties=edge_data,
@@ -395,6 +395,7 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(self),
             embedding_func=self.embedding_func,
         )
@@ -968,17 +969,21 @@ class LightRAG:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)
 
-    def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
+    def insert_custom_kg(
+        self, custom_kg: dict[str, Any], full_doc_id: str = None
+    ) -> None:
         loop = always_get_an_event_loop()
-        loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
+        loop.run_until_complete(self.ainsert_custom_kg(custom_kg, full_doc_id))
 
-    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
+    async def ainsert_custom_kg(
+        self, custom_kg: dict[str, Any], full_doc_id: str = None
+    ) -> None:
         update_storage = False
         try:
             # Insert chunks into vector storage
             all_chunks_data: dict[str, dict[str, str]] = {}
             chunk_to_source_map: dict[str, str] = {}
-            for chunk_data in custom_kg.get("chunks", {}):
+            for chunk_data in custom_kg.get("chunks", []):
                 chunk_content = self.clean_text(chunk_data["content"])
                 source_id = chunk_data["source_id"]
                 tokens = len(
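The shape of `custom_kg` can be read off the loops in this method: a dict with `chunks`, `entities`, and `relationships` lists. A minimal usage sketch of the new `full_doc_id` parameter follows; the field values are invented, and `rag` stands for an already-configured LightRAG instance.

# Hypothetical custom_kg payload; keys mirror those read by ainsert_custom_kg.
custom_kg = {
    "chunks": [
        {"content": "Alice founded Acme in 2001.", "source_id": "chunk-1"},
    ],
    "entities": [
        {
            "entity_name": "Alice",
            "entity_type": "PERSON",
            "description": "Founder of Acme",
            "source_id": "chunk-1",
        },
    ],
    "relationships": [
        {
            "src_id": "Alice",
            "tgt_id": "Acme",
            "description": "Alice founded Acme",
            "keywords": "founded",
            "weight": 1.0,
            "source_id": "chunk-1",
        },
    ],
}

# With the new parameter, every chunk is attributed to this document id
# instead of falling back to its own source_id.
rag.insert_custom_kg(custom_kg, full_doc_id="doc-001")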
@@ -998,7 +1003,9 @@ class LightRAG:
                     "source_id": source_id,
                     "tokens": tokens,
                     "chunk_order_index": chunk_order_index,
-                    "full_doc_id": source_id,
+                    "full_doc_id": full_doc_id
+                    if full_doc_id is not None
+                    else source_id,
                     "status": DocStatus.PROCESSED,
                 }
                 all_chunks_data[chunk_id] = chunk_entry
@@ -1006,9 +1013,10 @@ class LightRAG:
                 update_storage = True
 
             if all_chunks_data:
-                await self.chunks_vdb.upsert(all_chunks_data)
-            if all_chunks_data:
-                await self.text_chunks.upsert(all_chunks_data)
+                await asyncio.gather(
+                    self.chunks_vdb.upsert(all_chunks_data),
+                    self.text_chunks.upsert(all_chunks_data),
+                )
 
             # Insert entities into knowledge graph
             all_entities_data: list[dict[str, str]] = []
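`asyncio.gather` runs the two upserts concurrently instead of sequentially, which matters when both storages do network I/O. A self-contained sketch of the pattern; the coroutines here are stand-ins, not the real storage methods.

# Stand-alone illustration of the concurrent-upsert pattern adopted above.
import asyncio

async def upsert(store: str, n: int) -> str:
    await asyncio.sleep(0.1)  # stands in for network I/O
    return f"{store}: {n} records"

async def main():
    # Both writes are in flight at the same time, so total latency is
    # roughly the max of the two, not the sum.
    results = await asyncio.gather(upsert("chunks_vdb", 3), upsert("text_chunks", 3))
    print(results)

asyncio.run(main())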
@@ -1016,7 +1024,6 @@ class LightRAG:
                 entity_name = entity_data["entity_name"]
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
                 description = entity_data.get("description", "No description provided")
-                # source_id = entity_data["source_id"]
                 source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                 source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
@@ -1048,7 +1055,6 @@ class LightRAG:
                 description = relationship_data["description"]
                 keywords = relationship_data["keywords"]
                 weight = relationship_data.get("weight", 1.0)
-                # source_id = relationship_data["source_id"]
                 source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                 source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
@@ -1088,34 +1094,43 @@ class LightRAG:
                     "tgt_id": tgt_id,
                     "description": description,
                     "keywords": keywords,
                     "source_id": source_id,
                     "weight": weight,
                 }
                 all_relationships_data.append(edge_data)
                 update_storage = True
 
-            # Insert entities into vector storage if needed
+            # Insert entities into vector storage with consistent format
             data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "content": dp["entity_name"] + dp["description"],
+                    "content": dp["entity_name"] + "\n" + dp["description"],
                     "entity_name": dp["entity_name"],
+                    "source_id": dp["source_id"],
+                    "description": dp["description"],
+                    "entity_type": dp["entity_type"],
                 }
                 for dp in all_entities_data
             }
             await self.entities_vdb.upsert(data_for_vdb)
 
-            # Insert relationships into vector storage if needed
+            # Insert relationships into vector storage with consistent format
             data_for_vdb = {
                 compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                     "src_id": dp["src_id"],
                     "tgt_id": dp["tgt_id"],
-                    "content": dp["keywords"]
-                    + dp["src_id"]
-                    + dp["tgt_id"]
-                    + dp["description"],
                     "source_id": dp["source_id"],
+                    "content": f"{dp['keywords']}\t{dp['src_id']}\n{dp['tgt_id']}\n{dp['description']}",
+                    "keywords": dp["keywords"],
+                    "description": dp["description"],
+                    "weight": dp["weight"],
                 }
                 for dp in all_relationships_data
             }
             await self.relationships_vdb.upsert(data_for_vdb)
 
         except Exception as e:
             logger.error(f"Error in ainsert_custom_kg: {e}")
             raise
         finally:
             if update_storage:
                 await self._insert_done()
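Both vector-DB payloads now embed a single `content` string per record. A small sketch of what gets embedded for one relationship, using the same format string as above. The `compute_mdhash_id` here is a local md5-based stand-in for LightRAG's helper, shown only to make the id scheme concrete.

# Sketch: the relationship "content" string and id, mirroring the format above.
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Stand-in for LightRAG's helper of the same name.
    return prefix + md5(content.encode()).hexdigest()

dp = {
    "src_id": "Alice",
    "tgt_id": "Acme",
    "keywords": "founded",
    "description": "Alice founded Acme",
}

rel_id = compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-")
content = f"{dp['keywords']}\t{dp['src_id']}\n{dp['tgt_id']}\n{dp['description']}"
print(rel_id)    # rel-<md5 of "AliceAcme">
print(content)   # keywords, then src/tgt/description separated by tab/newlines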
@@ -1160,7 +1175,7 @@ class LightRAG:
         """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
-                query,
+                query.strip(),
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
                 self.relationships_vdb,
@@ -1181,7 +1196,7 @@ class LightRAG:
             )
         elif param.mode == "naive":
             response = await naive_query(
-                query,
+                query.strip(),
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
@@ -1200,7 +1215,7 @@ class LightRAG:
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
-                query,
+                query.strip(),
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
                 self.relationships_vdb,
@@ -1431,17 +1446,19 @@ class LightRAG:
             # 3. Before deleting, check the related entities and relationships for these chunks
             for chunk_id in chunk_ids:
                 # Check entities
+                entities_storage = await self.entities_vdb.client_storage
                 entities = [
                     dp
-                    for dp in self.entities_vdb.client_storage["data"]
+                    for dp in entities_storage["data"]
                     if chunk_id in dp.get("source_id")
                 ]
                 logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
 
                 # Check relationships
+                relationships_storage = await self.relationships_vdb.client_storage
                 relations = [
                     dp
-                    for dp in self.relationships_vdb.client_storage["data"]
+                    for dp in relationships_storage["data"]
                     if chunk_id in dp.get("source_id")
                 ]
                 logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
@@ -1505,7 +1522,9 @@ class LightRAG:
             for entity in entities_to_delete:
                 await self.entities_vdb.delete_entity(entity)
                 logger.debug(f"Deleted entity {entity} from vector DB")
-            self.chunk_entity_relation_graph.remove_nodes(list(entities_to_delete))
+            await self.chunk_entity_relation_graph.remove_nodes(
+                list(entities_to_delete)
+            )
             logger.debug(f"Deleted {len(entities_to_delete)} entities from graph")
 
             # Update entities
@@ -1524,7 +1543,7 @@ class LightRAG:
                 rel_id_1 = compute_mdhash_id(tgt + src, prefix="rel-")
                 await self.relationships_vdb.delete([rel_id_0, rel_id_1])
                 logger.debug(f"Deleted relationship {src}-{tgt} from vector DB")
-            self.chunk_entity_relation_graph.remove_edges(
+            await self.chunk_entity_relation_graph.remove_edges(
                 list(relationships_to_delete)
             )
             logger.debug(
@@ -1555,9 +1574,10 @@ class LightRAG:
 
         async def process_data(data_type, vdb, chunk_id):
             # Check data (entities or relationships)
+            storage = await vdb.client_storage
             data_with_chunk = [
                 dp
-                for dp in vdb.client_storage["data"]
+                for dp in storage["data"]
                 if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
             ]
 
@@ -1763,3 +1783,461 @@ class LightRAG:
     def clear_cache(self, modes: list[str] | None = None) -> None:
         """Synchronous version of aclear_cache."""
         return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
+
+    async def aedit_entity(
+        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+    ) -> dict[str, Any]:
+        """Asynchronously edit entity information.
+
+        Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
+
+        Args:
+            entity_name: Name of the entity to edit
+            updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
+            allow_rename: Whether to allow entity renaming, defaults to True
+
+        Returns:
+            Dictionary containing updated entity information
+        """
+        try:
+            # 1. Get current entity information
+            node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
+            if not node_data:
+                raise ValueError(f"Entity '{entity_name}' does not exist")
+
+            # Check if entity is being renamed
+            new_entity_name = updated_data.get("entity_name", entity_name)
+            is_renaming = new_entity_name != entity_name
+
+            # If renaming, check if new name already exists
+            if is_renaming:
+                if not allow_rename:
+                    raise ValueError(
+                        "Entity renaming is not allowed. Set allow_rename=True to enable this feature"
+                    )
+
+                existing_node = await self.chunk_entity_relation_graph.get_node(
+                    new_entity_name
+                )
+                if existing_node:
+                    raise ValueError(
+                        f"Entity name '{new_entity_name}' already exists, cannot rename"
+                    )
+
+            # 2. Update entity information in the graph
+            new_node_data = {**node_data, **updated_data}
+            if "entity_name" in new_node_data:
+                del new_node_data[
+                    "entity_name"
+                ]  # Node data should not contain entity_name field
+
+            # If renaming entity
+            if is_renaming:
+                logger.info(f"Renaming entity '{entity_name}' to '{new_entity_name}'")
+
+                # Create new entity
+                await self.chunk_entity_relation_graph.upsert_node(
+                    new_entity_name, new_node_data
+                )
+
+                # Get all edges related to the original entity
+                edges = await self.chunk_entity_relation_graph.get_node_edges(
+                    entity_name
+                )
+                if edges:
+                    # Recreate edges for the new entity
+                    for source, target in edges:
+                        edge_data = await self.chunk_entity_relation_graph.get_edge(
+                            source, target
+                        )
+                        if edge_data:
+                            if source == entity_name:
+                                await self.chunk_entity_relation_graph.upsert_edge(
+                                    new_entity_name, target, edge_data
+                                )
+                            else:  # target == entity_name
+                                await self.chunk_entity_relation_graph.upsert_edge(
+                                    source, new_entity_name, edge_data
+                                )
+
+                # Delete old entity
+                await self.chunk_entity_relation_graph.delete_node(entity_name)
+
+                # Delete old entity record from vector database
+                old_entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+                await self.entities_vdb.delete([old_entity_id])
+
+                # Update working entity name to new name
+                entity_name = new_entity_name
+            else:
+                # If not renaming, directly update node data
+                await self.chunk_entity_relation_graph.upsert_node(
+                    entity_name, new_node_data
+                )
+
+            # 3. Recalculate entity's vector representation and update vector database
+            description = new_node_data.get("description", "")
+            source_id = new_node_data.get("source_id", "")
+            entity_type = new_node_data.get("entity_type", "")
+            content = entity_name + "\n" + description
+
+            # Calculate entity ID
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+
+            # Prepare data for vector database update
+            entity_data = {
+                entity_id: {
+                    "content": content,
+                    "entity_name": entity_name,
+                    "source_id": source_id,
+                    "description": description,
+                    "entity_type": entity_type,
+                }
+            }
+
+            # Update vector database
+            await self.entities_vdb.upsert(entity_data)
+
+            # 4. Save changes
+            await self._edit_entity_done()
+
+            logger.info(f"Entity '{entity_name}' successfully updated")
+            return await self.get_entity_info(entity_name, include_vector_data=True)
+        except Exception as e:
+            logger.error(f"Error while editing entity '{entity_name}': {e}")
+            raise
+
+    def edit_entity(
+        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+    ) -> dict[str, Any]:
+        """Synchronously edit entity information.
+
+        Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
+
+        Args:
+            entity_name: Name of the entity to edit
+            updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
+            allow_rename: Whether to allow entity renaming, defaults to True
+
+        Returns:
+            Dictionary containing updated entity information
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.aedit_entity(entity_name, updated_data, allow_rename)
+        )
+
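A short usage sketch of the new editing API; the entity name and field values are invented, and `rag` stands for a configured LightRAG instance.

# Hypothetical example: update a description, then rename in the same call.
updated = rag.edit_entity(
    "Acme",
    {
        "description": "Acme Corporation, founded 2001",
        "entity_name": "Acme Corp",  # triggers the rename path above
    },
    allow_rename=True,
)
print(updated)  # entity info re-read from the graph and vector DB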
+    async def _edit_entity_done(self) -> None:
+        """Callback after entity editing is complete, ensures updates are persisted"""
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in [  # type: ignore
+                    self.entities_vdb,
+                    self.chunk_entity_relation_graph,
+                ]
+            ]
+        )
+
+    async def aedit_relation(
+        self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Asynchronously edit relation information.
+
+        Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
+
+        Args:
+            source_entity: Name of the source entity
+            target_entity: Name of the target entity
+            updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "keywords": "new keywords"}
+
+        Returns:
+            Dictionary containing updated relation information
+        """
+        try:
+            # 1. Get current relation information
+            edge_data = await self.chunk_entity_relation_graph.get_edge(
+                source_entity, target_entity
+            )
+            if not edge_data:
+                raise ValueError(
+                    f"Relation from '{source_entity}' to '{target_entity}' does not exist"
+                )
+
+            # 2. Update relation information in the graph
+            new_edge_data = {**edge_data, **updated_data}
+            await self.chunk_entity_relation_graph.upsert_edge(
+                source_entity, target_entity, new_edge_data
+            )
+
+            # 3. Recalculate relation's vector representation and update vector database
+            description = new_edge_data.get("description", "")
+            keywords = new_edge_data.get("keywords", "")
+            source_id = new_edge_data.get("source_id", "")
+            weight = float(new_edge_data.get("weight", 1.0))
+
+            # Create content for embedding
+            content = f"{keywords}\t{source_entity}\n{target_entity}\n{description}"
+
+            # Calculate relation ID
+            relation_id = compute_mdhash_id(
+                source_entity + target_entity, prefix="rel-"
+            )
+
+            # Prepare data for vector database update
+            relation_data = {
+                relation_id: {
+                    "content": content,
+                    "src_id": source_entity,
+                    "tgt_id": target_entity,
+                    "source_id": source_id,
+                    "description": description,
+                    "keywords": keywords,
+                    "weight": weight,
+                }
+            }
+
+            # Update vector database
+            await self.relationships_vdb.upsert(relation_data)
+
+            # 4. Save changes
+            await self._edit_relation_done()
+
+            logger.info(
+                f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
+            )
+            return await self.get_relation_info(
+                source_entity, target_entity, include_vector_data=True
+            )
+        except Exception as e:
+            logger.error(
+                f"Error while editing relation from '{source_entity}' to '{target_entity}': {e}"
+            )
+            raise
+
+    def edit_relation(
+        self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Synchronously edit relation information.
+
+        Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
+
+        Args:
+            source_entity: Name of the source entity
+            target_entity: Name of the target entity
+            updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "keywords": "new keywords"}
+
+        Returns:
+            Dictionary containing updated relation information
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.aedit_relation(source_entity, target_entity, updated_data)
+        )
+
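And the matching relation edit, again with invented values:

# Hypothetical example: reweight a relation and refresh its embedding.
updated = rag.edit_relation(
    "Alice",
    "Acme Corp",
    {"description": "Alice founded Acme Corp", "keywords": "founded", "weight": 2.0},
)
print(updated)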
+    async def _edit_relation_done(self) -> None:
+        """Callback after relation editing is complete, ensures updates are persisted"""
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in [  # type: ignore
+                    self.relationships_vdb,
+                    self.chunk_entity_relation_graph,
+                ]
+            ]
+        )
+
+    async def acreate_entity(
+        self, entity_name: str, entity_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Asynchronously create a new entity.
+
+        Creates a new entity in the knowledge graph and adds it to the vector database.
+
+        Args:
+            entity_name: Name of the new entity
+            entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"}
+
+        Returns:
+            Dictionary containing created entity information
+        """
+        try:
+            # Check if entity already exists
+            existing_node = await self.chunk_entity_relation_graph.get_node(entity_name)
+            if existing_node:
+                raise ValueError(f"Entity '{entity_name}' already exists")
+
+            # Prepare node data with defaults if missing
+            node_data = {
+                "entity_type": entity_data.get("entity_type", "UNKNOWN"),
+                "description": entity_data.get("description", ""),
+                "source_id": entity_data.get("source_id", "manual"),
+            }
+
+            # Add entity to knowledge graph
+            await self.chunk_entity_relation_graph.upsert_node(entity_name, node_data)
+
+            # Prepare content for entity
+            description = node_data.get("description", "")
+            source_id = node_data.get("source_id", "")
+            entity_type = node_data.get("entity_type", "")
+            content = entity_name + "\n" + description
+
+            # Calculate entity ID
+            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+
+            # Prepare data for vector database update
+            entity_data_for_vdb = {
+                entity_id: {
+                    "content": content,
+                    "entity_name": entity_name,
+                    "source_id": source_id,
+                    "description": description,
+                    "entity_type": entity_type,
+                }
+            }
+
+            # Update vector database
+            await self.entities_vdb.upsert(entity_data_for_vdb)
+
+            # Save changes
+            await self._edit_entity_done()
+
+            logger.info(f"Entity '{entity_name}' successfully created")
+            return await self.get_entity_info(entity_name, include_vector_data=True)
+        except Exception as e:
+            logger.error(f"Error while creating entity '{entity_name}': {e}")
+            raise
+
+    def create_entity(
+        self, entity_name: str, entity_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Synchronously create a new entity.
+
+        Creates a new entity in the knowledge graph and adds it to the vector database.
+
+        Args:
+            entity_name: Name of the new entity
+            entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"}
+
+        Returns:
+            Dictionary containing created entity information
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(self.acreate_entity(entity_name, entity_data))
+
+    async def acreate_relation(
+        self, source_entity: str, target_entity: str, relation_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Asynchronously create a new relation between entities.
+
+        Creates a new relation (edge) in the knowledge graph and adds it to the vector database.
+
+        Args:
+            source_entity: Name of the source entity
+            target_entity: Name of the target entity
+            relation_data: Dictionary containing relation attributes, e.g. {"description": "description", "keywords": "keywords"}
+
+        Returns:
+            Dictionary containing created relation information
+        """
+        try:
+            # Check if both entities exist
+            source_exists = await self.chunk_entity_relation_graph.has_node(
+                source_entity
+            )
+            target_exists = await self.chunk_entity_relation_graph.has_node(
+                target_entity
+            )
+
+            if not source_exists:
+                raise ValueError(f"Source entity '{source_entity}' does not exist")
+            if not target_exists:
+                raise ValueError(f"Target entity '{target_entity}' does not exist")
+
+            # Check if relation already exists
+            existing_edge = await self.chunk_entity_relation_graph.get_edge(
+                source_entity, target_entity
+            )
+            if existing_edge:
+                raise ValueError(
+                    f"Relation from '{source_entity}' to '{target_entity}' already exists"
+                )
+
+            # Prepare edge data with defaults if missing
+            edge_data = {
+                "description": relation_data.get("description", ""),
+                "keywords": relation_data.get("keywords", ""),
+                "source_id": relation_data.get("source_id", "manual"),
+                "weight": float(relation_data.get("weight", 1.0)),
+            }
+
+            # Add relation to knowledge graph
+            await self.chunk_entity_relation_graph.upsert_edge(
+                source_entity, target_entity, edge_data
+            )
+
+            # Prepare content for embedding
+            description = edge_data.get("description", "")
+            keywords = edge_data.get("keywords", "")
+            source_id = edge_data.get("source_id", "")
+            weight = edge_data.get("weight", 1.0)
+
+            # Create content for embedding
+            content = f"{keywords}\t{source_entity}\n{target_entity}\n{description}"
+
+            # Calculate relation ID
+            relation_id = compute_mdhash_id(
+                source_entity + target_entity, prefix="rel-"
+            )
+
+            # Prepare data for vector database update
+            relation_data_for_vdb = {
+                relation_id: {
+                    "content": content,
+                    "src_id": source_entity,
+                    "tgt_id": target_entity,
+                    "source_id": source_id,
+                    "description": description,
+                    "keywords": keywords,
+                    "weight": weight,
+                }
+            }
+
+            # Update vector database
+            await self.relationships_vdb.upsert(relation_data_for_vdb)
+
+            # Save changes
+            await self._edit_relation_done()
+
+            logger.info(
+                f"Relation from '{source_entity}' to '{target_entity}' successfully created"
+            )
+            return await self.get_relation_info(
+                source_entity, target_entity, include_vector_data=True
+            )
+        except Exception as e:
+            logger.error(
+                f"Error while creating relation from '{source_entity}' to '{target_entity}': {e}"
+            )
+            raise
+
+    def create_relation(
+        self, source_entity: str, target_entity: str, relation_data: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Synchronously create a new relation between entities.
+
+        Creates a new relation (edge) in the knowledge graph and adds it to the vector database.
+
+        Args:
+            source_entity: Name of the source entity
+            target_entity: Name of the target entity
+            relation_data: Dictionary containing relation attributes, e.g. {"description": "description", "keywords": "keywords"}
+
+        Returns:
+            Dictionary containing created relation information
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.acreate_relation(source_entity, target_entity, relation_data)
+        )
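Finally, a combined sketch of the creation API; entity names and attributes are invented, and `rag` is a configured LightRAG instance.

# Hypothetical example: create two entities, then link them.
rag.create_entity("Alice", {"description": "Founder of Acme", "entity_type": "PERSON"})
rag.create_entity("Acme", {"description": "A company", "entity_type": "ORGANIZATION"})

relation = rag.create_relation(
    "Alice",
    "Acme",
    {"description": "Alice founded Acme", "keywords": "founded", "weight": 1.0},
)
print(relation)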