Merge branch 'HKUDS:main' into main
@@ -70,7 +70,7 @@ def main():
         ),
         vector_storage="FaissVectorDBStorage",
         vector_db_storage_cls_kwargs={
-            "cosine_better_than_threshold": 0.3  # Your desired threshold
+            "cosine_better_than_threshold": 0.2  # Your desired threshold
        },
     )

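For context: the hunk above only lowers the retrieval cutoff. A self-contained sketch of the configuration it belongs to, with stub embedding/LLM functions; the dimension, paths, and function bodies below are illustrative stand-ins, not code from this commit.

import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

async def fake_embed(texts: list[str]) -> np.ndarray:
    # Stand-in embedding function so the sketch runs without a model.
    return np.random.rand(len(texts), 384)

async def fake_llm(prompt, **kwargs) -> str:
    return "stub answer"

rag = LightRAG(
    working_dir="./rag_storage",  # illustrative path
    llm_model_func=fake_llm,
    embedding_func=EmbeddingFunc(embedding_dim=384, max_token_size=8192, func=fake_embed),
    vector_storage="FaissVectorDBStorage",
    vector_db_storage_cls_kwargs={
        # Only hits whose cosine similarity beats this value are returned;
        # lowering 0.3 -> 0.2 admits weaker matches into the context.
        "cosine_better_than_threshold": 0.2
    },
)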
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.1.7"
+__version__ = "1.1.11"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -1748,7 +1748,16 @@ def create_app(args):
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))

+    # query all graph labels
+    @app.get("/graph/label/list")
+    async def get_graph_labels():
+        return await rag.get_graph_labels()
+
     # query all graph
+    @app.get("/graphs")
+    async def get_knowledge_graph(label: str):
+        return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
+
     # Add Ollama API routes
     ollama_api = OllamaAPI(rag, top_k=args.top_k)
     app.include_router(ollama_api.router, prefix="/api")
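A sketch of a client exercising the two new routes, assuming the server's default address and that the KnowledgeGraph model serializes to "nodes"/"edges" arrays (both assumptions; adjust to your deployment). Note that /graphs forwards the label via the nodel_label keyword and hard-codes max_depth=100.

import asyncio
import httpx

async def main() -> None:
    # Base URL is an assumption; point it at the running LightRAG API server.
    async with httpx.AsyncClient(base_url="http://localhost:9621") as client:
        labels = (await client.get("/graph/label/list")).json()
        print("labels:", labels)
        if labels:
            graph = (await client.get("/graphs", params={"label": labels[0]})).json()
            print(len(graph["nodes"]), "nodes |", len(graph["edges"]), "edges")

asyncio.run(main())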
@@ -13,6 +13,7 @@ from typing import (
 )
 import numpy as np
 from .utils import EmbeddingFunc
+from .types import KnowledgeGraph

 load_dotenv()

@@ -197,6 +198,16 @@ class BaseGraphStorage(StorageNameSpace, ABC):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         """Get all labels in the graph."""

+    @abstractmethod
+    async def get_all_labels(self) -> list[str]:
+        """Get a knowledge graph of a node."""
+
+    @abstractmethod
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """Retrieve a subgraph of the knowledge graph starting from a given node."""
+

 class DocStatus(str, Enum):
     """Document processing status"""
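Every graph backend must now answer these two queries. A toy, dict-backed illustration of the contract (a hypothetical class, not part of the codebase) that returns plain lists and tuples instead of the real KnowledgeGraph model:

import asyncio
from dataclasses import dataclass, field

@dataclass
class TinyGraph:
    # adjacency: node label -> set of neighbour labels
    adj: dict[str, set[str]] = field(default_factory=dict)

    async def get_all_labels(self) -> list[str]:
        return sorted(self.adj)

    async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> dict:
        # Breadth-first expansion up to max_depth hops, mirroring what the real
        # backends do with $graphLookup / apoc.path.subgraphAll further down.
        seen, frontier = {node_label}, [node_label]
        for _ in range(max_depth):
            nxt = {n for cur in frontier for n in self.adj.get(cur, ()) if n not in seen}
            seen.update(nxt)
            frontier = list(nxt)
        edges = [(a, b) for a in sorted(seen) for b in self.adj.get(a, ()) if b in seen]
        return {"nodes": sorted(seen), "edges": edges}

g = TinyGraph({"a": {"b"}, "b": {"c"}, "c": set()})
print(asyncio.run(g.get_all_labels()))             # ['a', 'b', 'c']
print(asyncio.run(g.get_knowledge_graph("a", 1)))  # nodes a, b; edge ('a', 'b')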
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Union, final
 import numpy as np
 import pipmaster as pm
+from lightrag.types import KnowledgeGraph

 from tenacity import (
     retry,
@@ -59,6 +60,10 @@ class AGEQueryException(Exception):
 @final
 @dataclass
 class AGEStorage(BaseGraphStorage):
+    @staticmethod
+    def load_nx_graph(file_name):
+        print("no preloading of graph with AGE in production")
+
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
             namespace=namespace,
@@ -615,6 +620,14 @@ class AGEStorage(BaseGraphStorage):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
     async def index_done_callback(self) -> None:
         # AGES handles persistence automatically
         pass
@@ -16,6 +16,7 @@ from tenacity import (
     wait_exponential,
 )

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import logger

 from ..base import BaseGraphStorage
@@ -401,3 +402,11 @@ class GremlinStorage(BaseGraphStorage):
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
@@ -16,6 +16,7 @@ from ..base import (
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 import pipmaster as pm

 if not pm.is_installed("pymongo"):
@@ -598,6 +599,197 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     # QUERY
     # -------------------------------------------------------------------------
+    #

+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all existing node _id in the database
+        Returns:
+            [id1, id2, ...]  # Alphabetically sorted id list
+        """
+        # Use MongoDB's distinct and aggregation to get all unique labels
+        pipeline = [
+            {"$group": {"_id": "$_id"}},  # Group by _id
+            {"$sort": {"_id": 1}},  # Sort alphabetically
+        ]
+
+        cursor = self.collection.aggregate(pipeline)
+        labels = []
+        async for doc in cursor:
+            labels.append(doc["_id"])
+        return labels
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """
+        Get complete connected subgraph for specified node (including the starting node itself)
+
+        Args:
+            node_label: Label of the nodes to start from
+            max_depth: Maximum depth of traversal (default: 5)
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges of the subgraph
+        """
+        label = node_label
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        try:
+            if label == "*":
+                # Get all nodes and edges
+                async for node_doc in self.collection.find({}):
+                    node_id = str(node_doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_doc.get("_id")],
+                                properties={
+                                    k: v
+                                    for k, v in node_doc.items()
+                                    if k not in ["_id", "edges"]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Process edges
+                    for edge in node_doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+            else:
+                # Verify if starting node exists
+                start_nodes = self.collection.find({"_id": label})
+                start_nodes_exist = await start_nodes.to_list(length=1)
+                if not start_nodes_exist:
+                    logger.warning(f"Starting node with label {label} does not exist!")
+                    return result
+
+                # Use $graphLookup for traversal
+                pipeline = [
+                    {
+                        "$match": {"_id": label}
+                    },  # Start with nodes having the specified label
+                    {
+                        "$graphLookup": {
+                            "from": self._collection_name,
+                            "startWith": "$edges.target",
+                            "connectFromField": "edges.target",
+                            "connectToField": "_id",
+                            "maxDepth": max_depth,
+                            "depthField": "depth",
+                            "as": "connected_nodes",
+                        }
+                    },
+                ]
+
+                async for doc in self.collection.aggregate(pipeline):
+                    # Add the start node
+                    node_id = str(doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[
+                                    doc.get(
+                                        "_id",
+                                    )
+                                ],
+                                properties={
+                                    k: v
+                                    for k, v in doc.items()
+                                    if k
+                                    not in [
+                                        "_id",
+                                        "edges",
+                                        "connected_nodes",
+                                        "depth",
+                                    ]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Add edges from start node
+                    for edge in doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+
+                    # Add connected nodes and their edges
+                    for connected in doc.get("connected_nodes", []):
+                        node_id = str(connected["_id"])
+                        if node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=node_id,
+                                    labels=[connected.get("_id")],
+                                    properties={
+                                        k: v
+                                        for k, v in connected.items()
+                                        if k not in ["_id", "edges", "depth"]
+                                    },
+                                )
+                            )
+                            seen_nodes.add(node_id)
+
+                            # Add edges from connected nodes
+                            for edge in connected.get("edges", []):
+                                edge_id = f"{node_id}-{edge['target']}"
+                                if edge_id not in seen_edges:
+                                    result.edges.append(
+                                        KnowledgeGraphEdge(
+                                            id=edge_id,
+                                            type=edge.get("relation", ""),
+                                            source=node_id,
+                                            target=edge["target"],
+                                            properties={
+                                                k: v
+                                                for k, v in edge.items()
+                                                if k not in ["target", "relation"]
+                                            },
+                                        )
+                                    )
+                                    seen_edges.add(edge_id)
+
+            logger.info(
+                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+            )
+
+        except PyMongoError as e:
+            logger.error(f"MongoDB query failed: {str(e)}")
+
+        return result
+
     async def index_done_callback(self) -> None:
         # Mongo handles persistence automatically
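The traversal above leans entirely on MongoDB's $graphLookup stage. A stripped-down sketch of just that stage with motor (URI, database, and collection names are illustrative): each matched start document gains a connected_nodes array holding every document reachable by following edges.target links, up to maxDepth extra hops.

import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

async def reachable(label: str, max_depth: int = 5) -> list[str]:
    coll = AsyncIOMotorClient("mongodb://localhost:27017")["lightrag"]["graph"]
    pipeline = [
        {"$match": {"_id": label}},
        {"$graphLookup": {
            "from": "graph",  # self-join: must name this same collection
            "startWith": "$edges.target",
            "connectFromField": "edges.target",
            "connectToField": "_id",
            "maxDepth": max_depth,
            "depthField": "depth",
            "as": "connected_nodes",
        }},
    ]
    ids: list[str] = []
    async for doc in coll.aggregate(pipeline):
        ids = [c["_id"] for c in doc["connected_nodes"]]
    return ids

print(asyncio.run(reachable("SomeEntity")))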
@@ -17,6 +17,7 @@ from tenacity import (

 from ..utils import logger
 from ..base import BaseGraphStorage
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 import pipmaster as pm

 if not pm.is_installed("neo4j"):
@@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage):
     async def _node2vec_embed(self):
         print("Implemented but never called.")

+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """
+        Get complete connected subgraph for specified node (including the starting node itself)
+
+        Key fixes:
+        1. Include the starting node itself
+        2. Handle multi-label nodes
+        3. Clarify relationship directions
+        4. Add depth control
+        """
+        label = node_label.strip('"')
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        async with self._driver.session(database=self._DATABASE) as session:
+            try:
+                main_query = ""
+                if label == "*":
+                    main_query = """
+                    MATCH (n)
+                    WITH collect(DISTINCT n) AS nodes
+                    MATCH ()-[r]-()
+                    RETURN nodes, collect(DISTINCT r) AS relationships;
+                    """
+                else:
+                    # Critical debug step: first verify if starting node exists
+                    validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
+                    validate_result = await session.run(validate_query)
+                    if not await validate_result.single():
+                        logger.warning(f"Starting node {label} does not exist!")
+                        return result
+
+                    # Optimized query (including direction handling and self-loops)
+                    main_query = f"""
+                    MATCH (start:`{label}`)
+                    WITH start
+                    CALL apoc.path.subgraphAll(start, {{
+                        relationshipFilter: '>',
+                        minLevel: 0,
+                        maxLevel: {max_depth},
+                        bfs: true
+                    }})
+                    YIELD nodes, relationships
+                    RETURN nodes, relationships
+                    """
+                result_set = await session.run(main_query)
+                record = await result_set.single()
+
+                if record:
+                    # Handle nodes (compatible with multi-label cases)
+                    for node in record["nodes"]:
+                        # Use node ID + label combination as unique identifier
+                        node_id = node.id
+                        if node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=f"{node_id}",
+                                    labels=list(node.labels),
+                                    properties=dict(node),
+                                )
+                            )
+                            seen_nodes.add(node_id)
+
+                    # Handle relationships (including direction information)
+                    for rel in record["relationships"]:
+                        edge_id = rel.id
+                        if edge_id not in seen_edges:
+                            start = rel.start_node
+                            end = rel.end_node
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=f"{edge_id}",
+                                    type=rel.type,
+                                    source=f"{start.id}",
+                                    target=f"{end.id}",
+                                    properties=dict(rel),
+                                )
+                            )
+                            seen_edges.add(edge_id)
+
+                    logger.info(
+                        f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                    )
+
+            except neo4jExceptions.ClientError as e:
+                logger.error(f"APOC query failed: {str(e)}")
+                return await self._robust_fallback(label, max_depth)
+
+        return result
+
     async def _robust_fallback(
         self, label: str, max_depth: int
     ) -> Dict[str, List[Dict]]:
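The happy path above requires the APOC plugin. A standalone sketch of the same expansion via the async Neo4j driver (URI, credentials, and SomeLabel are placeholders): relationshipFilter '>' limits expansion to outgoing relationships, and minLevel 0 keeps the starting node itself in the result.

import asyncio
from neo4j import AsyncGraphDatabase

QUERY = """
MATCH (start:`SomeLabel`)
WITH start
CALL apoc.path.subgraphAll(start, {
  relationshipFilter: '>',
  minLevel: 0,
  maxLevel: 5,
  bfs: true
})
YIELD nodes, relationships
RETURN nodes, relationships
"""

async def main() -> None:
    driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    async with driver.session() as session:
        record = await (await session.run(QUERY)).single()
        if record:
            print(len(record["nodes"]), "nodes,", len(record["relationships"]), "relationships")
    await driver.close()

asyncio.run(main())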
@@ -534,6 +628,31 @@ class Neo4JStorage(BaseGraphStorage):
         await traverse(label, 0)
         return result

+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all existing node labels in the database
+        Returns:
+            ["Person", "Company", ...]  # Alphabetically sorted label list
+        """
+        async with self._driver.session(database=self._DATABASE) as session:
+            # Method 1: Direct metadata query (Available for Neo4j 4.3+)
+            # query = "CALL db.labels() YIELD label RETURN label"
+
+            # Method 2: Query compatible with older versions
+            query = """
+            MATCH (n)
+            WITH DISTINCT labels(n) AS node_labels
+            UNWIND node_labels AS label
+            RETURN DISTINCT label
+            ORDER BY label
+            """
+
+            result = await session.run(query)
+            labels = []
+            async for record in result:
+                labels.append(record["label"])
+            return labels
+
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

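"Method 1" in the comment above is left disabled. A sketch of what it looks like on Neo4j 4.3+, where the label catalog can be read directly instead of scanning every node (connection details are placeholders):

import asyncio
from neo4j import AsyncGraphDatabase

async def all_labels() -> list[str]:
    driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    async with driver.session() as session:
        result = await session.run("CALL db.labels() YIELD label RETURN label ORDER BY label")
        labels = [record["label"] async for record in result]
    await driver.close()
    return labels

print(asyncio.run(all_labels()))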
@@ -5,6 +5,7 @@ from typing import Any, final
 import numpy as np


+from lightrag.types import KnowledgeGraph
 from lightrag.utils import (
     logger,
 )
@@ -16,11 +17,12 @@ import pipmaster as pm

 if not pm.is_installed("networkx"):
     pm.install("networkx")

 if not pm.is_installed("graspologic"):
     pm.install("graspologic")

-from graspologic import embed
 import networkx as nx
+from graspologic import embed
+

 @final
@@ -165,3 +167,11 @@ class NetworkXStorage(BaseGraphStorage):
         for source, target in edges:
             if self._graph.has_edge(source, target):
                 self._graph.remove_edge(source, target)
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
@@ -8,6 +8,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser

+from lightrag.types import KnowledgeGraph

 from ..base import (
     BaseGraphStorage,
@@ -669,6 +670,14 @@ class OracleGraphStorage(BaseGraphStorage):
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
@@ -7,6 +7,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser

+from lightrag.types import KnowledgeGraph

 import sys
 from tenacity import (
@@ -177,10 +178,12 @@ class PostgreSQLDB:
                 asyncpg.exceptions.UniqueViolationError,
                 asyncpg.exceptions.DuplicateTableError,
             ) as e:
-                if not upsert:
-                    logger.error(f"PostgreSQL, upsert error: {e}")
+                if upsert:
+                    print("Key value duplicate, but upsert succeeded.")
+                else:
+                    logger.error(f"Upsert error: {e}")
             except Exception as e:
-                logger.error(f"PostgreSQL database, sql:{sql}, data:{data}, error:{e}")
+                logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
                 raise

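The rewritten branch treats a unique-key violation raised during an upsert as benign instead of logging it as an error. For comparison, a sketch of the ON CONFLICT form that avoids raising at all (table, columns, and DSN are illustrative, not LightRAG's schema):

import asyncio
import asyncpg

async def upsert_kv(key: str, value: str) -> None:
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost/lightrag")
    # Insert-or-update in one statement, so concurrent writers never raise
    # UniqueViolationError in the first place.
    await conn.execute(
        """
        INSERT INTO kv_store (k, v) VALUES ($1, $2)
        ON CONFLICT (k) DO UPDATE SET v = EXCLUDED.v
        """,
        key,
        value,
    )
    await conn.close()

asyncio.run(upsert_kv("doc-1", "hello"))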
@@ -1084,6 +1087,14 @@ class PGGraphStorage(BaseGraphStorage):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
     async def drop(self) -> None:
         """Drop the storage"""
         drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
@@ -5,6 +5,8 @@ from typing import Any, Union, final

 import numpy as np

+from lightrag.types import KnowledgeGraph
+
 from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
 from ..namespace import NameSpace, is_namespace
@@ -558,6 +560,14 @@ class TiDBGraphStorage(BaseGraphStorage):
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
@@ -47,6 +47,7 @@ from .utils import (
     set_logger,
     encode_string_by_tiktoken,
 )
+from .types import KnowledgeGraph

 # TODO: TO REMOVE @Yannick
 config = configparser.ConfigParser()
@@ -184,7 +185,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     embedding_cache_config: dict[str, Any] = field(
-        default={
+        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
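This one-line change is a real fix, not a style tweak: dataclasses reject a mutable default= outright, and a single dict shared across instances would be a bug even if they did not. A minimal repro, independent of LightRAG:

from dataclasses import dataclass, field
from typing import Any

try:
    @dataclass
    class Broken:
        cfg: dict[str, Any] = field(default={"enabled": False})
except ValueError as e:
    print(e)  # mutable default <class 'dict'> for field cfg is not allowed

@dataclass
class Fixed:
    cfg: dict[str, Any] = field(default_factory=lambda: {"enabled": False})

a, b = Fixed(), Fixed()
a.cfg["enabled"] = True
print(b.cfg["enabled"])  # False: every instance gets its own dict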
@@ -457,6 +458,17 @@ class LightRAG:
         self._storages_status = StoragesStatus.FINALIZED
         logger.debug("Finalized Storages")

+    async def get_graph_labels(self):
+        text = await self.chunk_entity_relation_graph.get_all_labels()
+        return text
+
+    async def get_knowledge_graph(
+        self, nodel_label: str, max_depth: int
+    ) -> KnowledgeGraph:
+        return await self.chunk_entity_relation_graph.get_knowledge_graph(
+            node_label=nodel_label, max_depth=max_depth
+        )
+
     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
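A sketch of the two new public methods in use, assuming an already-initialized LightRAG instance named rag (note the keyword is spelled nodel_label in this version):

import asyncio

async def inspect_graph(rag) -> None:
    labels = await rag.get_graph_labels()
    print("labels:", labels[:10])
    if labels:
        kg = await rag.get_knowledge_graph(nodel_label=labels[0], max_depth=3)
        print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")

# asyncio.run(inspect_graph(rag))  # run with a configured instance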
@@ -727,7 +739,7 @@ class LightRAG:

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
-            new_kg = await extract_entities(
+            await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
@@ -735,13 +747,6 @@ class LightRAG:
                 llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
             )
-            if new_kg is None:
-                logger.info("No new entities or relationships extracted.")
-            else:
-                async with self._entity_lock:
-                    logger.info("New entities or relationships extracted.")
-                    self.chunk_entity_relation_graph = new_kg
-
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e
@@ -329,7 +329,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> BaseGraphStorage | None:
+) -> None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -522,16 +522,18 @@ async def extract_entities(
         ]
     )

-    if not len(all_entities_data) and not len(all_relationships_data):
-        logger.warning(
-            "Didn't extract any entities and relationships, maybe your LLM is not working"
-        )
-        return None
+    if not (all_entities_data or all_relationships_data):
+        logger.info("Didn't extract any entities and relationships.")
+        return

-    if not len(all_entities_data):
-        logger.warning("Didn't extract any entities")
-    if not len(all_relationships_data):
-        logger.warning("Didn't extract any relationships")
+    if not all_entities_data:
+        logger.info("Didn't extract any entities")
+    if not all_relationships_data:
+        logger.info("Didn't extract any relationships")
+
+    logger.info(
+        f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}"
+    )

     if entity_vdb is not None:
         data_for_vdb = {
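The replacements above lean on container truthiness: for any list xs, `not len(xs)` and `not xs` are equivalent, and the latter is the idiomatic spelling. A quick check:

xs: list[int] = []
assert (not len(xs)) == (not xs)  # holds for any list
assert not ([] or [])             # mirrors: not (all_entities_data or all_relationships_data)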
@@ -560,8 +562,6 @@ async def extract_entities(
         }
         await relationships_vdb.upsert(data_for_vdb)

-    return knowledge_graph_inst
-

 async def kg_query(
     query: str,
@@ -1,6 +1,6 @@
 from typing import Optional, Tuple, Dict, List
 import numpy as np
+import networkx as nx
 import pipmaster as pm

 # Added automatic libraries install using pipmaster
@@ -12,10 +12,7 @@ if not pm.is_installed("pyglm"):
     pm.install("pyglm")
 if not pm.is_installed("python-louvain"):
     pm.install("python-louvain")
-if not pm.is_installed("networkx"):
-    pm.install("networkx")

-import networkx as nx
 import moderngl
 from imgui_bundle import imgui, immapp, hello_imgui
 import community
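The install-on-import pattern used throughout these files, in isolation: ask pipmaster whether a distribution is present and install it before the real import runs (tqdm is just an example package).

import pipmaster as pm

if not pm.is_installed("tqdm"):  # any pip-installable distribution name
    pm.install("tqdm")

import tqdm  # safe to import once ensured
print(tqdm.__version__)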