Use KnowledgeGraph Pydantic models (not a typed dict) for the graph API response

This commit is contained in:
ArnoChen
2025-02-13 17:32:51 +08:00
parent e4562c761c
commit c674905a98
5 changed files with 54 additions and 30 deletions

View File

@@ -1424,8 +1424,8 @@ def create_app(args):
# query all graph # query all graph
@app.get("/graphs") @app.get("/graphs")
async def get_graphs(label: str): async def get_knowledge_graph(label: str):
return await rag.get_graphs(nodel_label=label, max_depth=100) return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
# Add Ollama API routes # Add Ollama API routes
ollama_api = OllamaAPI(rag) ollama_api = OllamaAPI(rag)

View File

@@ -13,6 +13,7 @@ from typing import (
import numpy as np import numpy as np
from .utils import EmbeddingFunc from .utils import EmbeddingFunc
from .types import KnowledgeGraph
class TextChunkSchema(TypedDict): class TextChunkSchema(TypedDict):
@@ -175,7 +176,7 @@ class BaseGraphStorage(StorageNameSpace):
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> dict[str, list[dict]]: ) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError

View File

@@ -25,6 +25,7 @@ from tenacity import (
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@dataclass @dataclass
@@ -44,7 +45,8 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"] URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"] USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) MAX_CONNECTION_POOL_SIZE = os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
) )
@@ -74,19 +76,22 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise e raise e
except neo4jExceptions.AuthError as e: except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}") logger.error(
f"Authentication failed for {database} at {URI}")
raise e raise e
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound": if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info( logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize() f"{database} at {URI} not found. Try to create specified database.".capitalize(
)
) )
try: try:
with _sync_driver.session() as session: with _sync_driver.session() as session:
session.run( session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS" f"CREATE DATABASE `{database}` IF NOT EXISTS"
) )
logger.info(f"{database} at {URI} created".capitalize()) logger.info(
f"{database} at {URI} created".capitalize())
connected = True connected = True
except ( except (
neo4jExceptions.ClientError, neo4jExceptions.ClientError,
@@ -103,7 +108,8 @@ class Neo4JStorage(BaseGraphStorage):
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
) )
if database is None: if database is None:
logger.error(f"Failed to create {database} at {URI}") logger.error(
f"Failed to create {database} at {URI}")
raise e raise e
if connected: if connected:
@@ -365,7 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> Dict[str, List[Dict]]: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@@ -376,7 +382,7 @@ class Neo4JStorage(BaseGraphStorage):
4. Add depth control 4. Add depth control
""" """
label = node_label.strip('"') label = node_label.strip('"')
result = {"nodes": [], "edges": []} result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
@@ -395,7 +401,8 @@ class Neo4JStorage(BaseGraphStorage):
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query) validate_result = await session.run(validate_query)
if not await validate_result.single(): if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!") logger.warning(
f"Starting node {label} does not exist!")
return result return result
# Optimized query (including direction handling and self-loops) # Optimized query (including direction handling and self-loops)
@@ -420,11 +427,11 @@ class Neo4JStorage(BaseGraphStorage):
# Use node ID + label combination as unique identifier # Use node ID + label combination as unique identifier
node_id = node.id node_id = node.id
if node_id not in seen_nodes: if node_id not in seen_nodes:
node_data = {} result.nodes.append(KnowledgeGraphNode(
node_data["labels"] = list(node.labels) # Keep all labels id=f"{node_id}",
node_data["id"] = f"{node_id}" labels=list(node.labels),
node_data["properties"] = dict(node) properties=dict(node),
result["nodes"].append(node_data) ))
seen_nodes.add(node_id) seen_nodes.add(node_id)
# Handle relationships (including direction information) # Handle relationships (including direction information)
@@ -433,21 +440,17 @@ class Neo4JStorage(BaseGraphStorage):
if edge_id not in seen_edges: if edge_id not in seen_edges:
start = rel.start_node start = rel.start_node
end = rel.end_node end = rel.end_node
edge_data = {} result.edges.append(KnowledgeGraphEdge(
edge_data.update( id=f"{edge_id}",
{ type=rel.type,
"source": f"{start.id}", source=f"{start.id}",
"target": f"{end.id}", target=f"{end.id}",
"type": rel.type, properties=dict(rel),
"id": f"{edge_id}", ))
"properties": dict(rel),
}
)
result["edges"].append(edge_data)
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:

View File

@@ -34,6 +34,7 @@ from .utils import (
logger, logger,
set_logger, set_logger,
) )
from .types import KnowledgeGraph
STORAGES = { STORAGES = {
"NetworkXStorage": ".kg.networkx_impl", "NetworkXStorage": ".kg.networkx_impl",
@@ -385,7 +386,7 @@ class LightRAG:
text = await self.chunk_entity_relation_graph.get_all_labels() text = await self.chunk_entity_relation_graph.get_all_labels()
return text return text
async def get_graphs(self, nodel_label: str, max_depth: int): async def get_knowledge_graph(self, nodel_label: str, max_depth: int) -> KnowledgeGraph:
return await self.chunk_entity_relation_graph.get_knowledge_graph( return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label=nodel_label, max_depth=max_depth node_label=nodel_label, max_depth=max_depth
) )

View File

@@ -1,7 +1,26 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import List from typing import List, Dict, Any
class GPTKeywordExtractionFormat(BaseModel): class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str] high_level_keywords: List[str]
low_level_keywords: List[str] low_level_keywords: List[str]
class KnowledgeGraphNode(BaseModel):
    """One node of a knowledge graph in the graph API response."""

    # Node identifier, carried as a string.
    id: str
    # All labels attached to the node.
    labels: List[str]
    properties: Dict[str, Any]  # anything else goes here
class KnowledgeGraphEdge(BaseModel):
    """One relationship (edge) between two nodes of a knowledge graph."""

    # Edge identifier, carried as a string.
    id: str
    # Relationship type name.
    type: str
    source: str  # id of source node
    target: str  # id of target node
    properties: Dict[str, Any]  # anything else goes here
class KnowledgeGraph(BaseModel):
    """A (sub)graph expressed as a list of nodes and a list of edges.

    Both fields default to empty, so ``KnowledgeGraph()`` is an empty graph
    that callers can fill by appending to ``nodes`` and ``edges``.
    """

    # NOTE(review): pydantic is documented to copy mutable field defaults per
    # instance, so these `[]` defaults should not be shared between graphs --
    # confirm against the pydantic version this project pins.
    nodes: List[KnowledgeGraphNode] = []
    edges: List[KnowledgeGraphEdge] = []