use KnowledgeGraph typed dict for graph API response
This commit is contained in:
@@ -1424,8 +1424,8 @@ def create_app(args):
|
|||||||
|
|
||||||
# query all graph
|
# query all graph
|
||||||
@app.get("/graphs")
|
@app.get("/graphs")
|
||||||
async def get_graphs(label: str):
|
async def get_knowledge_graph(label: str):
|
||||||
return await rag.get_graphs(nodel_label=label, max_depth=100)
|
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
|
||||||
|
|
||||||
# Add Ollama API routes
|
# Add Ollama API routes
|
||||||
ollama_api = OllamaAPI(rag)
|
ollama_api = OllamaAPI(rag)
|
||||||
|
@@ -13,6 +13,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .utils import EmbeddingFunc
|
from .utils import EmbeddingFunc
|
||||||
|
from .types import KnowledgeGraph
|
||||||
|
|
||||||
|
|
||||||
class TextChunkSchema(TypedDict):
|
class TextChunkSchema(TypedDict):
|
||||||
@@ -175,7 +176,7 @@ class BaseGraphStorage(StorageNameSpace):
|
|||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5
|
self, node_label: str, max_depth: int = 5
|
||||||
) -> dict[str, list[dict]]:
|
) -> KnowledgeGraph:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@@ -25,6 +25,7 @@ from tenacity import (
|
|||||||
|
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -44,7 +45,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
URI = os.environ["NEO4J_URI"]
|
URI = os.environ["NEO4J_URI"]
|
||||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||||
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
|
MAX_CONNECTION_POOL_SIZE = os.environ.get(
|
||||||
|
"NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
|
||||||
DATABASE = os.environ.get(
|
DATABASE = os.environ.get(
|
||||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
||||||
)
|
)
|
||||||
@@ -74,19 +76,22 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
except neo4jExceptions.AuthError as e:
|
except neo4jExceptions.AuthError as e:
|
||||||
logger.error(f"Authentication failed for {database} at {URI}")
|
logger.error(
|
||||||
|
f"Authentication failed for {database} at {URI}")
|
||||||
raise e
|
raise e
|
||||||
except neo4jExceptions.ClientError as e:
|
except neo4jExceptions.ClientError as e:
|
||||||
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
f"{database} at {URI} not found. Try to create specified database.".capitalize(
|
||||||
|
)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with _sync_driver.session() as session:
|
with _sync_driver.session() as session:
|
||||||
session.run(
|
session.run(
|
||||||
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
||||||
)
|
)
|
||||||
logger.info(f"{database} at {URI} created".capitalize())
|
logger.info(
|
||||||
|
f"{database} at {URI} created".capitalize())
|
||||||
connected = True
|
connected = True
|
||||||
except (
|
except (
|
||||||
neo4jExceptions.ClientError,
|
neo4jExceptions.ClientError,
|
||||||
@@ -103,7 +108,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
||||||
)
|
)
|
||||||
if database is None:
|
if database is None:
|
||||||
logger.error(f"Failed to create {database} at {URI}")
|
logger.error(
|
||||||
|
f"Failed to create {database} at {URI}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if connected:
|
if connected:
|
||||||
@@ -365,7 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5
|
self, node_label: str, max_depth: int = 5
|
||||||
) -> Dict[str, List[Dict]]:
|
) -> KnowledgeGraph:
|
||||||
"""
|
"""
|
||||||
Get complete connected subgraph for specified node (including the starting node itself)
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
@@ -376,7 +382,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
4. Add depth control
|
4. Add depth control
|
||||||
"""
|
"""
|
||||||
label = node_label.strip('"')
|
label = node_label.strip('"')
|
||||||
result = {"nodes": [], "edges": []}
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
|
|
||||||
@@ -395,7 +401,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
||||||
validate_result = await session.run(validate_query)
|
validate_result = await session.run(validate_query)
|
||||||
if not await validate_result.single():
|
if not await validate_result.single():
|
||||||
logger.warning(f"Starting node {label} does not exist!")
|
logger.warning(
|
||||||
|
f"Starting node {label} does not exist!")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Optimized query (including direction handling and self-loops)
|
# Optimized query (including direction handling and self-loops)
|
||||||
@@ -420,11 +427,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
# Use node ID + label combination as unique identifier
|
# Use node ID + label combination as unique identifier
|
||||||
node_id = node.id
|
node_id = node.id
|
||||||
if node_id not in seen_nodes:
|
if node_id not in seen_nodes:
|
||||||
node_data = {}
|
result.nodes.append(KnowledgeGraphNode(
|
||||||
node_data["labels"] = list(node.labels) # Keep all labels
|
id=f"{node_id}",
|
||||||
node_data["id"] = f"{node_id}"
|
labels=list(node.labels),
|
||||||
node_data["properties"] = dict(node)
|
properties=dict(node),
|
||||||
result["nodes"].append(node_data)
|
))
|
||||||
seen_nodes.add(node_id)
|
seen_nodes.add(node_id)
|
||||||
|
|
||||||
# Handle relationships (including direction information)
|
# Handle relationships (including direction information)
|
||||||
@@ -433,21 +440,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
if edge_id not in seen_edges:
|
if edge_id not in seen_edges:
|
||||||
start = rel.start_node
|
start = rel.start_node
|
||||||
end = rel.end_node
|
end = rel.end_node
|
||||||
edge_data = {}
|
result.edges.append(KnowledgeGraphEdge(
|
||||||
edge_data.update(
|
id=f"{edge_id}",
|
||||||
{
|
type=rel.type,
|
||||||
"source": f"{start.id}",
|
source=f"{start.id}",
|
||||||
"target": f"{end.id}",
|
target=f"{end.id}",
|
||||||
"type": rel.type,
|
properties=dict(rel),
|
||||||
"id": f"{edge_id}",
|
))
|
||||||
"properties": dict(rel),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result["edges"].append(edge_data)
|
|
||||||
seen_edges.add(edge_id)
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
|
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except neo4jExceptions.ClientError as e:
|
except neo4jExceptions.ClientError as e:
|
||||||
|
@@ -34,6 +34,7 @@ from .utils import (
|
|||||||
logger,
|
logger,
|
||||||
set_logger,
|
set_logger,
|
||||||
)
|
)
|
||||||
|
from .types import KnowledgeGraph
|
||||||
|
|
||||||
STORAGES = {
|
STORAGES = {
|
||||||
"NetworkXStorage": ".kg.networkx_impl",
|
"NetworkXStorage": ".kg.networkx_impl",
|
||||||
@@ -385,7 +386,7 @@ class LightRAG:
|
|||||||
text = await self.chunk_entity_relation_graph.get_all_labels()
|
text = await self.chunk_entity_relation_graph.get_all_labels()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
async def get_graphs(self, nodel_label: str, max_depth: int):
|
async def get_knowledge_graph(self, nodel_label: str, max_depth: int) -> KnowledgeGraph:
|
||||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||||
node_label=nodel_label, max_depth=max_depth
|
node_label=nodel_label, max_depth=max_depth
|
||||||
)
|
)
|
||||||
|
@@ -1,7 +1,26 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class GPTKeywordExtractionFormat(BaseModel):
|
class GPTKeywordExtractionFormat(BaseModel):
|
||||||
high_level_keywords: List[str]
|
high_level_keywords: List[str]
|
||||||
low_level_keywords: List[str]
|
low_level_keywords: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeGraphNode(BaseModel):
|
||||||
|
id: str
|
||||||
|
labels: List[str]
|
||||||
|
properties: Dict[str, Any] # anything else goes here
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeGraphEdge(BaseModel):
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
source: str # id of source node
|
||||||
|
target: str # id of target node
|
||||||
|
properties: Dict[str, Any] # anything else goes here
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeGraph(BaseModel):
|
||||||
|
nodes: List[KnowledgeGraphNode] = []
|
||||||
|
edges: List[KnowledgeGraphEdge] = []
|
||||||
|
Reference in New Issue
Block a user