Merge branch 'HKUDS:main' into main

Saifeddine ALOUI
2025-02-20 18:15:55 +01:00
committed by GitHub
15 changed files with 426 additions and 31 deletions

View File

@@ -70,7 +70,7 @@ def main():
         ),
         vector_storage="FaissVectorDBStorage",
         vector_db_storage_cls_kwargs={
-            "cosine_better_than_threshold": 0.3  # Your desired threshold
+            "cosine_better_than_threshold": 0.2  # Your desired threshold
         },
     )

View File

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.1.7"
+__version__ = "1.1.11"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -1748,7 +1748,16 @@ def create_app(args):
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))

+    # query all graph labels
+    @app.get("/graph/label/list")
+    async def get_graph_labels():
+        return await rag.get_graph_labels()
+
     # query all graph
+    @app.get("/graphs")
+    async def get_knowledge_graph(label: str):
+        return await rag.get_knowledge_graph(node_label=label, max_depth=100)
+
     # Add Ollama API routes
     ollama_api = OllamaAPI(rag, top_k=args.top_k)
     app.include_router(ollama_api.router, prefix="/api")
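
Note: a quick way to exercise the two new routes once the server is running; a sketch that assumes the default server address and that the returned `KnowledgeGraph` model serializes to `nodes`/`edges` arrays:

    import httpx  # any HTTP client works

    base = "http://localhost:9621"  # assumed default address; adjust to your deployment

    # List every label known to the graph storage
    labels = httpx.get(f"{base}/graph/label/list").json()
    print(labels)

    # Fetch the subgraph around one label (the route pins max_depth to 100)
    if labels:
        graph = httpx.get(f"{base}/graphs", params={"label": labels[0]}).json()
        print(len(graph["nodes"]), len(graph["edges"]))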

View File

@@ -13,6 +13,7 @@ from typing import (
 )

 import numpy as np
 from .utils import EmbeddingFunc
+from .types import KnowledgeGraph

 load_dotenv()
@@ -197,6 +198,16 @@ class BaseGraphStorage(StorageNameSpace, ABC):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         """Get all labels in the graph."""

+    @abstractmethod
+    async def get_all_labels(self) -> list[str]:
+        """Get all node labels in the graph."""
+
+    @abstractmethod
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """Retrieve a subgraph of the knowledge graph starting from a given node."""
+

 class DocStatus(str, Enum):
     """Document processing status"""

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Union, final

 import numpy as np
 import pipmaster as pm
+from lightrag.types import KnowledgeGraph

 from tenacity import (
     retry,
@@ -59,6 +60,10 @@ class AGEQueryException(Exception):
 @final
 @dataclass
 class AGEStorage(BaseGraphStorage):
+    @staticmethod
+    def load_nx_graph(file_name):
+        print("no preloading of graph with AGE in production")
+
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
             namespace=namespace,
@@ -615,6 +620,14 @@ class AGEStorage(BaseGraphStorage):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
     async def index_done_callback(self) -> None:
         # AGES handles persistence automatically
         pass

View File

@@ -16,6 +16,7 @@ from tenacity import (
     wait_exponential,
 )

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import logger

 from ..base import BaseGraphStorage
@@ -401,3 +402,11 @@ class GremlinStorage(BaseGraphStorage):
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError

View File

@@ -16,6 +16,7 @@ from ..base import (
 )

 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 import pipmaster as pm

 if not pm.is_installed("pymongo"):
@@ -598,6 +599,197 @@ class MongoGraphStorage(BaseGraphStorage):
     # -------------------------------------------------------------------------
     # QUERY
     # -------------------------------------------------------------------------

+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all existing node _id values in the database
+
+        Returns:
+            [id1, id2, ...]  # Alphabetically sorted id list
+        """
+        # Use a MongoDB aggregation to collect and sort all unique ids
+        pipeline = [
+            {"$group": {"_id": "$_id"}},  # Group by _id
+            {"$sort": {"_id": 1}},  # Sort alphabetically
+        ]
+        cursor = self.collection.aggregate(pipeline)
+
+        labels = []
+        async for doc in cursor:
+            labels.append(doc["_id"])
+        return labels
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """
+        Get the complete connected subgraph for a specified node (including the starting node itself)
+
+        Args:
+            node_label: Label of the node to start from
+            max_depth: Maximum depth of traversal (default: 5)
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges of the subgraph
+        """
+        label = node_label
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        try:
+            if label == "*":
+                # Get all nodes and edges
+                async for node_doc in self.collection.find({}):
+                    node_id = str(node_doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_doc.get("_id")],
+                                properties={
+                                    k: v
+                                    for k, v in node_doc.items()
+                                    if k not in ["_id", "edges"]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Process edges
+                    for edge in node_doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+            else:
+                # Verify that the starting node exists
+                start_nodes = self.collection.find({"_id": label})
+                start_nodes_exist = await start_nodes.to_list(length=1)
+                if not start_nodes_exist:
+                    logger.warning(f"Starting node with label {label} does not exist!")
+                    return result
+
+                # Use $graphLookup for traversal
+                pipeline = [
+                    {
+                        "$match": {"_id": label}
+                    },  # Start with nodes having the specified label
+                    {
+                        "$graphLookup": {
+                            "from": self._collection_name,
+                            "startWith": "$edges.target",
+                            "connectFromField": "edges.target",
+                            "connectToField": "_id",
+                            "maxDepth": max_depth,
+                            "depthField": "depth",
+                            "as": "connected_nodes",
+                        }
+                    },
+                ]
+
+                async for doc in self.collection.aggregate(pipeline):
+                    # Add the start node
+                    node_id = str(doc["_id"])
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[doc.get("_id")],
+                                properties={
+                                    k: v
+                                    for k, v in doc.items()
+                                    if k
+                                    not in [
+                                        "_id",
+                                        "edges",
+                                        "connected_nodes",
+                                        "depth",
+                                    ]
+                                },
+                            )
+                        )
+                        seen_nodes.add(node_id)
+
+                    # Add edges from the start node
+                    for edge in doc.get("edges", []):
+                        edge_id = f"{node_id}-{edge['target']}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=edge.get("relation", ""),
+                                    source=node_id,
+                                    target=edge["target"],
+                                    properties={
+                                        k: v
+                                        for k, v in edge.items()
+                                        if k not in ["target", "relation"]
+                                    },
+                                )
+                            )
+                            seen_edges.add(edge_id)
+
+                    # Add connected nodes and their edges
+                    for connected in doc.get("connected_nodes", []):
+                        node_id = str(connected["_id"])
+                        if node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=node_id,
+                                    labels=[connected.get("_id")],
+                                    properties={
+                                        k: v
+                                        for k, v in connected.items()
+                                        if k not in ["_id", "edges", "depth"]
+                                    },
+                                )
+                            )
+                            seen_nodes.add(node_id)
+
+                        # Add edges from connected nodes
+                        for edge in connected.get("edges", []):
+                            edge_id = f"{node_id}-{edge['target']}"
+                            if edge_id not in seen_edges:
+                                result.edges.append(
+                                    KnowledgeGraphEdge(
+                                        id=edge_id,
+                                        type=edge.get("relation", ""),
+                                        source=node_id,
+                                        target=edge["target"],
+                                        properties={
+                                            k: v
+                                            for k, v in edge.items()
+                                            if k not in ["target", "relation"]
+                                        },
+                                    )
+                                )
+                                seen_edges.add(edge_id)
+
+            logger.info(
+                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+            )
+        except PyMongoError as e:
+            logger.error(f"MongoDB query failed: {str(e)}")
+
+        return result
+
     async def index_done_callback(self) -> None:
         # Mongo handles persistence automatically
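
Note: the Mongo traversal above assumes each node document embeds its outgoing edges inline; a sketch of the document shape it expects (inferred from the code, not defined anywhere in this diff):

    # Hypothetical node document in the graph collection
    example_node = {
        "_id": "PERSON_A",  # node id, also reused as its label
        "description": "...",  # arbitrary extra keys become node properties
        "edges": [
            {"target": "COMPANY_B", "relation": "WORKS_AT", "weight": 1.0},
        ],
    }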

View File

@@ -17,6 +17,7 @@ from tenacity import (
 from ..utils import logger
 from ..base import BaseGraphStorage
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 import pipmaster as pm

 if not pm.is_installed("neo4j"):
@@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage):
     async def _node2vec_embed(self):
         print("Implemented but never called.")

+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        """
+        Get the complete connected subgraph for a specified node (including the starting node itself)
+
+        Key fixes:
+        1. Include the starting node itself
+        2. Handle multi-label nodes
+        3. Clarify relationship directions
+        4. Add depth control
+        """
+        label = node_label.strip('"')
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+
+        async with self._driver.session(database=self._DATABASE) as session:
+            try:
+                main_query = ""
+                if label == "*":
+                    main_query = """
+                    MATCH (n)
+                    WITH collect(DISTINCT n) AS nodes
+                    MATCH ()-[r]-()
+                    RETURN nodes, collect(DISTINCT r) AS relationships;
+                    """
+                else:
+                    # First verify that the starting node exists
+                    validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
+                    validate_result = await session.run(validate_query)
+                    if not await validate_result.single():
+                        logger.warning(f"Starting node {label} does not exist!")
+                        return result
+
+                    # Optimized query (including direction handling and self-loops)
+                    main_query = f"""
+                    MATCH (start:`{label}`)
+                    WITH start
+                    CALL apoc.path.subgraphAll(start, {{
+                        relationshipFilter: '>',
+                        minLevel: 0,
+                        maxLevel: {max_depth},
+                        bfs: true
+                    }})
+                    YIELD nodes, relationships
+                    RETURN nodes, relationships
+                    """
+
+                result_set = await session.run(main_query)
+                record = await result_set.single()
+
+                if record:
+                    # Handle nodes (compatible with multi-label cases)
+                    for node in record["nodes"]:
+                        # Use the internal node ID as the unique identifier
+                        node_id = node.id
+                        if node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=f"{node_id}",
+                                    labels=list(node.labels),
+                                    properties=dict(node),
+                                )
+                            )
+                            seen_nodes.add(node_id)
+
+                    # Handle relationships (including direction information)
+                    for rel in record["relationships"]:
+                        edge_id = rel.id
+                        if edge_id not in seen_edges:
+                            start = rel.start_node
+                            end = rel.end_node
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=f"{edge_id}",
+                                    type=rel.type,
+                                    source=f"{start.id}",
+                                    target=f"{end.id}",
+                                    properties=dict(rel),
+                                )
+                            )
+                            seen_edges.add(edge_id)
+
+                    logger.info(
+                        f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                    )
+            except neo4jExceptions.ClientError as e:
+                logger.error(f"APOC query failed: {str(e)}")
+                return await self._robust_fallback(label, max_depth)
+
+        return result
+
     async def _robust_fallback(
         self, label: str, max_depth: int
     ) -> Dict[str, List[Dict]]:
@@ -534,6 +628,31 @@ class Neo4JStorage(BaseGraphStorage):
         await traverse(label, 0)
         return result

+    async def get_all_labels(self) -> list[str]:
+        """
+        Get all existing node labels in the database
+
+        Returns:
+            ["Person", "Company", ...]  # Alphabetically sorted label list
+        """
+        async with self._driver.session(database=self._DATABASE) as session:
+            # Method 1: Direct metadata query (available on Neo4j 4.3+)
+            # query = "CALL db.labels() YIELD label RETURN label"
+
+            # Method 2: Query compatible with older versions
+            query = """
+            MATCH (n)
+            WITH DISTINCT labels(n) AS node_labels
+            UNWIND node_labels AS label
+            RETURN DISTINCT label
+            ORDER BY label
+            """
+
+            result = await session.run(query)
+            labels = []
+            async for record in result:
+                labels.append(record["label"])
+            return labels
+
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError
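
Note: both new methods share the `KnowledgeGraph` return shape, so they can be exercised uniformly against any backend that implements them; a small driver sketch (storage construction is elided and backend-specific):

    # Sketch: walking the new graph API; `storage` is any BaseGraphStorage
    # implementation that provides the two methods added in this commit.
    async def dump_subgraph(storage, depth: int = 3) -> None:
        labels = await storage.get_all_labels()
        if not labels:
            return
        kg = await storage.get_knowledge_graph(labels[0], max_depth=depth)
        for node in kg.nodes:
            print(node.id, node.labels)
        for edge in kg.edges:
            print(edge.source, f"-[{edge.type}]->", edge.target)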

View File

@@ -5,6 +5,7 @@ from typing import Any, final
 import numpy as np

+from lightrag.types import KnowledgeGraph
 from lightrag.utils import (
     logger,
 )
@@ -16,11 +17,12 @@ import pipmaster as pm
if not pm.is_installed("networkx"): if not pm.is_installed("networkx"):
pm.install("networkx") pm.install("networkx")
if not pm.is_installed("graspologic"): if not pm.is_installed("graspologic"):
pm.install("graspologic") pm.install("graspologic")
from graspologic import embed
import networkx as nx import networkx as nx
from graspologic import embed
@final @final
@@ -165,3 +167,11 @@ class NetworkXStorage(BaseGraphStorage):
         for source, target in edges:
             if self._graph.has_edge(source, target):
                 self._graph.remove_edge(source, target)
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
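
Note: the NetworkX backend leaves both methods unimplemented in this commit; for reference, a plausible implementation sketch (not part of this diff; it treats each node id as its own label, mirroring the Mongo backend, and uses `nx.ego_graph` for the depth-limited neighborhood):

    async def get_all_labels(self) -> list[str]:
        # Sketch: node ids double as labels in this storage
        return sorted(str(n) for n in self._graph.nodes)

    async def get_knowledge_graph(
        self, node_label: str, max_depth: int = 5
    ) -> KnowledgeGraph:
        result = KnowledgeGraph()
        if node_label == "*":
            sub = self._graph
        elif self._graph.has_node(node_label):
            # Breadth-limited neighborhood around the start node
            sub = nx.ego_graph(self._graph, node_label, radius=max_depth)
        else:
            return result
        for node, data in sub.nodes(data=True):
            result.nodes.append(
                KnowledgeGraphNode(
                    id=str(node), labels=[str(node)], properties=dict(data)
                )
            )
        for u, v, data in sub.edges(data=True):
            result.edges.append(
                KnowledgeGraphEdge(
                    id=f"{u}-{v}",
                    type=data.get("relation", ""),
                    source=str(u),
                    target=str(v),
                    properties=dict(data),
                )
            )
        return result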

View File

@@ -8,6 +8,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser

+from lightrag.types import KnowledgeGraph

 from ..base import (
     BaseGraphStorage,
@@ -669,6 +670,14 @@ class OracleGraphStorage(BaseGraphStorage):
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -7,6 +7,7 @@ from typing import Any, Union, final
 import numpy as np
 import configparser

+from lightrag.types import KnowledgeGraph

 import sys
 from tenacity import (
@@ -177,10 +178,12 @@ class PostgreSQLDB:
             asyncpg.exceptions.UniqueViolationError,
             asyncpg.exceptions.DuplicateTableError,
         ) as e:
-            if not upsert:
-                logger.error(f"PostgreSQL, upsert error: {e}")
+            if upsert:
+                print("Key value duplicate, but upsert succeeded.")
+            else:
+                logger.error(f"Upsert error: {e}")
         except Exception as e:
-            logger.error(f"PostgreSQL database, sql:{sql}, data:{data}, error:{e}")
+            logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
             raise
@@ -1084,6 +1087,14 @@ class PGGraphStorage(BaseGraphStorage):
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+
     async def drop(self) -> None:
         """Drop the storage"""
         drop_sql = SQL_TEMPLATES["drop_vdb_entity"]

View File

@@ -5,6 +5,8 @@ from typing import Any, Union, final
 import numpy as np

+from lightrag.types import KnowledgeGraph

 from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
 from ..namespace import NameSpace, is_namespace
@@ -558,6 +560,14 @@ class TiDBGraphStorage(BaseGraphStorage):
     async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError

+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -47,6 +47,7 @@ from .utils import (
     set_logger,
     encode_string_by_tiktoken,
 )
+from .types import KnowledgeGraph

 # TODO: TO REMOVE @Yannick
 config = configparser.ConfigParser()
@@ -184,7 +185,7 @@ class LightRAG:
"""Maximum number of concurrent embedding function calls.""" """Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field( embedding_cache_config: dict[str, Any] = field(
default={ default_factory=lambda: {
"enabled": False, "enabled": False,
"similarity_threshold": 0.95, "similarity_threshold": 0.95,
"use_llm_check": False, "use_llm_check": False,
@@ -457,6 +458,17 @@ class LightRAG:
             self._storages_status = StoragesStatus.FINALIZED
             logger.debug("Finalized Storages")

+    async def get_graph_labels(self):
+        return await self.chunk_entity_relation_graph.get_all_labels()
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int
+    ) -> KnowledgeGraph:
+        return await self.chunk_entity_relation_graph.get_knowledge_graph(
+            node_label=node_label, max_depth=max_depth
+        )
+
     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
@@ -727,7 +739,7 @@ class LightRAG:
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
-            new_kg = await extract_entities(
+            await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
@@ -735,13 +747,6 @@ class LightRAG:
                 llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
             )
-            if new_kg is None:
-                logger.info("No new entities or relationships extracted.")
-            else:
-                async with self._entity_lock:
-                    logger.info("New entities or relationships extracted.")
-                    self.chunk_entity_relation_graph = new_kg
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e
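
Note: the `default_factory` change a few hunks above fixes the classic mutable-default pitfall in dataclasses: `field(default={...})` raises `ValueError: mutable default ... is not allowed` on recent Python versions, and a shared default dict would otherwise leak state across instances. A minimal illustration:

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class Config:
        # default_factory builds a fresh dict per instance
        cache: dict[str, Any] = field(
            default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
        )

    a, b = Config(), Config()
    a.cache["enabled"] = True
    print(b.cache["enabled"])  # False: each instance owns its dict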

View File

@@ -329,7 +329,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> BaseGraphStorage | None:
+) -> None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -522,16 +522,18 @@ async def extract_entities(
         ]
     )

-    if not len(all_entities_data) and not len(all_relationships_data):
-        logger.warning(
-            "Didn't extract any entities and relationships, maybe your LLM is not working"
-        )
-        return None
+    if not (all_entities_data or all_relationships_data):
+        logger.info("Didn't extract any entities and relationships.")
+        return

-    if not len(all_entities_data):
-        logger.warning("Didn't extract any entities")
-    if not len(all_relationships_data):
-        logger.warning("Didn't extract any relationships")
+    if not all_entities_data:
+        logger.info("Didn't extract any entities")
+    if not all_relationships_data:
+        logger.info("Didn't extract any relationships")
+
+    logger.info(
+        f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}"
+    )

     if entity_vdb is not None:
         data_for_vdb = {
@@ -560,8 +562,6 @@ async def extract_entities(
     }
     await relationships_vdb.upsert(data_for_vdb)

-    return knowledge_graph_inst
-

 async def kg_query(
     query: str,
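
Note: with the `BaseGraphStorage | None` return dropped, `extract_entities` now communicates solely by mutating `knowledge_graph_inst` in place, which is why `_process_entity_relation_graph` above no longer reassigns `self.chunk_entity_relation_graph`. A caller-side sketch of the new contract (argument values abbreviated):

    # Sketch: the storage passed in is updated in place; nothing is returned.
    await extract_entities(
        chunk,
        knowledge_graph_inst=graph_storage,
        entity_vdb=entities_vdb,
        relationships_vdb=relationships_vdb,
        global_config=config,
    )
    # graph_storage now holds any newly extracted entities and relations.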

View File

@@ -1,6 +1,6 @@
 from typing import Optional, Tuple, Dict, List
 import numpy as np
+import networkx as nx
 import pipmaster as pm

 # Added automatic libraries install using pipmaster
@@ -12,10 +12,7 @@ if not pm.is_installed("pyglm"):
pm.install("pyglm") pm.install("pyglm")
if not pm.is_installed("python-louvain"): if not pm.is_installed("python-louvain"):
pm.install("python-louvain") pm.install("python-louvain")
if not pm.is_installed("networkx"):
pm.install("networkx")
import networkx as nx
import moderngl import moderngl
from imgui_bundle import imgui, immapp, hello_imgui from imgui_bundle import imgui, immapp, hello_imgui
import community import community