Merge branch 'main' into multi-language

baoheping
2025-03-08 19:04:32 +08:00
committed by GitHub
40 changed files with 3048 additions and 246 deletions


@@ -849,6 +849,76 @@ All operations are available in both synchronous and asynchronous versions. The
These operations maintain data consistency across both the graph database and vector database components, ensuring your knowledge graph remains coherent.
## Entity Merging
<details>
<summary> <b>Merge Entities and Their Relationships</b> </summary>
LightRAG now supports merging multiple entities into a single entity, automatically handling all relationships:
```python
# Basic entity merging
rag.merge_entities(
    source_entities=["Artificial Intelligence", "AI", "Machine Intelligence"],
    target_entity="AI Technology"
)
```
With custom merge strategy:
```python
# Define custom merge strategy for different fields
rag.merge_entities(
    source_entities=["John Smith", "Dr. Smith", "J. Smith"],
    target_entity="John Smith",
    merge_strategy={
        "description": "concatenate",  # Combine all descriptions
        "entity_type": "keep_first",   # Keep the entity type from the first entity
        "source_id": "join_unique"     # Combine all unique source IDs
    }
)
```
With custom target entity data:
```python
# Specify exact values for the merged entity
rag.merge_entities(
    source_entities=["New York", "NYC", "Big Apple"],
    target_entity="New York City",
    target_entity_data={
        "entity_type": "LOCATION",
        "description": "New York City is the most populous city in the United States.",
    }
)
```
Advanced usage combining both approaches:
```python
# Merge company entities with both strategy and custom data
rag.merge_entities(
    source_entities=["Microsoft Corp", "Microsoft Corporation", "MSFT"],
    target_entity="Microsoft",
    merge_strategy={
        "description": "concatenate",  # Combine all descriptions
        "source_id": "join_unique"     # Combine source IDs
    },
    target_entity_data={
        "entity_type": "ORGANIZATION",
    }
)
```
When merging entities:
* All relationships from source entities are redirected to the target entity
* Duplicate relationships are intelligently merged
* Self-relationships (loops) are prevented
* Source entities are removed after merging
* Relationship weights and attributes are preserved
</details>
## Cache
<details>


@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
### Redis
REDIS_URI=redis://localhost:6379
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/login,/health # white list


@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku
```
## Authentication Endpoints
### JWT Authentication Mechanism
The LightRAG API Server implements JWT-based authentication using the HS256 algorithm. To enable secure access control, the following environment variables are required:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/api1,/api2 # whitelist; /login,/health,/docs,/redoc,/openapi.json are whitelisted by default
```
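Once these variables are set, clients obtain a token from `/login` and attach it as a Bearer header. The following is a minimal client-side sketch (not part of this commit), assuming the API Server listens on `http://localhost:9621` and the credentials match the values configured above:

```python
import requests

BASE_URL = "http://localhost:9621"  # assumed server address; adjust to your deployment

# Exchange the configured username/password for a JWT (OAuth2 password form fields)
resp = requests.post(
    f"{BASE_URL}/login",
    data={"username": "admin", "password": "admin123"},
)
resp.raise_for_status()
token = resp.json()["access_token"]

# Attach the token as a Bearer header on calls to protected routes
# (/documents is used here as an assumed example; whitelisted paths skip this check)
headers = {"Authorization": f"Bearer {token}"}
print(requests.get(f"{BASE_URL}/documents", headers=headers).status_code)
```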
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When the API Server is running, visit:

lightrag/api/auth.py Normal file

@@ -0,0 +1,41 @@
import os
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel


class TokenPayload(BaseModel):
    sub: str
    exp: datetime


class AuthHandler:
    def __init__(self):
        self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
        self.algorithm = "HS256"
        self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))

    def create_token(self, username: str) -> str:
        expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
        payload = TokenPayload(sub=username, exp=expire)
        return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> str:
        try:
            payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
            expire_timestamp = payload["exp"]
            expire_time = datetime.utcfromtimestamp(expire_timestamp)

            if datetime.utcnow() > expire_time:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
                )

            return payload["sub"]
        except jwt.PyJWTError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            )


auth_handler = AuthHandler()
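Since `auth_handler` is instantiated at module level, the `/login` route and the route dependencies simply import it. A rough standalone sketch of how the handler is exercised, assuming `TOKEN_SECRET` is set in the environment (otherwise the hard-coded default above applies):

```python
# Hypothetical standalone usage of the AuthHandler defined above
from lightrag.api.auth import auth_handler

token = auth_handler.create_token("admin")     # issue a JWT for the given username
username = auth_handler.validate_token(token)  # returns "admin", or raises HTTP 401
print(username)
```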


@@ -2,10 +2,7 @@
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -19,7 +16,7 @@ from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_api_key_dependency,
parse_args,
get_default_host,
@@ -29,14 +26,14 @@ from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.api.routers.document_routes import (
DocumentManager,
create_document_routes,
run_scanning_process,
)
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
@@ -45,9 +42,13 @@ from lightrag.kg.shared_storage import (
initialize_pipeline_status,
get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
# This update allows the user to put a different .env file for each lightrag folder
load_dotenv(".env", override=True)
# Initialize config parser
config = configparser.ConfigParser()
@@ -370,6 +371,27 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authentication not configured",
)
if form_data.username != username or form_data.password != password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
return {
"access_token": auth_handler.create_token(username),
"token_type": "bearer",
}
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""


@@ -1,10 +1,20 @@
aiofiles
ascii_colors
asyncpg
distro
fastapi
httpcore
httpx
jiter
numpy
openai
passlib[bcrypt]
pipmaster
PyJWT
python-dotenv
python-jose[cryptography]
python-multipart
pytz
tenacity
tiktoken
uvicorn


@@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_auth_dependency())],
)
# Temporary file prefix
temp_prefix = "__tmp__"


@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency, get_auth_dependency
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
def create_graph_routes(rag, api_key: Optional[str] = None):
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
):
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
inclusive (bool, optional): If True, search for nodes that include the label. Defaults to False.
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
inclusive=inclusive,
min_degree=min_degree,
)
return router
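For reference, a hedged example of exercising the updated endpoint with the new `min_degree` and `inclusive` query parameters. The route is registered as `/graphs` in this router; the server address and route prefix are assumptions, and the Bearer header from the login example above should be added if authentication is enabled:

```python
import requests

BASE_URL = "http://localhost:9621"  # assumed server address

# Fetch a subgraph around nodes whose label contains "AI" (inclusive=True enables
# substring matching); nodes with degree < 2 are dropped unless they are matching
# nodes or their direct neighbors.
resp = requests.get(
    f"{BASE_URL}/graphs",
    params={"label": "AI", "max_depth": 2, "min_degree": 2, "inclusive": True},
)
print(resp.json())
```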


@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
class QueryRequest(BaseModel):


@@ -9,6 +9,11 @@ import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
from dotenv import load_dotenv
# Updated to use the .env that is inside the current folder
# This update allows the user to put a different .env file for each lightrag folder
load_dotenv(".env")
def check_and_install_dependencies():


@@ -9,10 +9,11 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security, Depends, Request
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
@@ -31,6 +32,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
async def dependency(
request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
):
if request.url.path in whitelist:
return
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
return
auth_handler.validate_token(token)
return dependency
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
</head>
<body>
<div id="root"></div>


@@ -204,7 +204,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 3
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""


@@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import (
retry,
@@ -613,20 +613,260 @@ class AGEStorage(BaseGraphStorage):
await self._driver.putconn(connection)
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label
Args:
node_id: The label of the node to delete
"""
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`)
DETACH DELETE n
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
try:
await self._query(query, **params)
logger.debug(f"Deleted node with label '{entity_name_label}'")
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node labels to be deleted
"""
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
entity_name_label_source = source.strip('"')
entity_name_label_target = target.strip('"')
query = """
MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`)
DELETE r
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
try:
await self._query(query, **params)
logger.debug(
f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'"
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Embed nodes using the specified algorithm
Args:
algorithm: Name of the embedding algorithm
Returns:
tuple: (embedding matrix, list of node identifiers)
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def get_all_labels(self) -> list[str]:
"""Get all node labels in the database
Returns:
["label1", "label2", ...] # Alphabetically sorted label list
"""
query = """
MATCH (n)
RETURN DISTINCT labels(n) AS node_labels
"""
results = await self._query(query)
all_labels = []
for record in results:
if record and "node_labels" in record:
for label in record["node_labels"]:
if label:
# Decode label
decoded_label = AGEStorage._decode_graph_label(label)
all_labels.append(decoded_label)
# Remove duplicates and sort
return sorted(list(set(all_labels)))
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'.
Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence (nodes containing the specified label string)
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label: String to match in node labels (will match any node containing this string in its label)
max_depth: Maximum depth of the graph. Defaults to 5.
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000))
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
# Handle special case for "*" label
if node_label == "*":
# Query all nodes and sort by degree
query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) AS degree
ORDER BY degree DESC
LIMIT {max_nodes}
RETURN n, degree
"""
params = {"max_nodes": max_graph_nodes}
nodes_result = await self._query(query, **params)
# Add nodes to result
node_ids = []
for record in nodes_result:
if "n" in record:
node = record["n"]
node_id = str(node.get("id", ""))
if node_id not in seen_nodes:
node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "")
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties,
)
)
seen_nodes.add(node_id)
node_ids.append(node_id)
# Query edges between these nodes
if node_ids:
edges_query = """
MATCH (a)-[r]->(b)
WHERE a.id IN {node_ids} AND b.id IN {node_ids}
RETURN a, r, b
"""
edges_params = {"node_ids": node_ids}
edges_result = await self._query(edges_query, **edges_params)
# Add edges to result
for record in edges_result:
if "r" in record and "a" in record and "b" in record:
source = record["a"].get("id", "")
target = record["b"].get("id", "")
edge_id = f"{source}-{target}"
if edge_id not in seen_edges:
edge_properties = {k: v for k, v in record["r"].items()}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=source,
target=target,
properties=edge_properties,
)
)
seen_edges.add(edge_id)
else:
# For specific label, use partial matching
entity_name_label = node_label.strip('"')
encoded_label = AGEStorage._encode_graph_label(entity_name_label)
# Find matching start nodes
start_query = """
MATCH (n:`{label}`)
RETURN n
"""
start_params = {"label": encoded_label}
start_nodes = await self._query(start_query, **start_params)
if not start_nodes:
logger.warning(f"No nodes found with label '{entity_name_label}'!")
return result
# Traverse graph from each start node
for start_node_record in start_nodes:
if "n" in start_node_record:
# Use BFS to traverse graph
query = """
MATCH (start:`{label}`)
CALL {
MATCH path = (start)-[*0..{max_depth}]->(n)
RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels
}
RETURN DISTINCT path_nodes, path_rels
"""
params = {"label": encoded_label, "max_depth": max_depth}
results = await self._query(query, **params)
# Extract nodes and edges from results
for record in results:
if "path_nodes" in record:
# Process nodes
for node in record["path_nodes"]:
node_id = str(node.get("id", ""))
if (
node_id not in seen_nodes
and len(seen_nodes) < max_graph_nodes
):
node_properties = {k: v for k, v in node.items()}
node_label = node.get("label", "")
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_label],
properties=node_properties,
)
)
seen_nodes.add(node_id)
if "path_rels" in record:
# Process edges
for rel in record["path_rels"]:
source = str(rel.get("start_id", ""))
target = str(rel.get("end_id", ""))
edge_id = f"{source}-{target}"
if edge_id not in seen_edges:
edge_properties = {k: v for k, v in rel.items()}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=rel.get("label", "DIRECTED"),
source=source,
target=target,
properties=edge_properties,
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def index_done_callback(self) -> None:
# AGES handles persistence automatically


@@ -193,7 +193,79 @@ class ChromaVectorDBStorage(BaseVectorStorage):
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its ID.
Args:
entity_name: The ID of the entity to delete
"""
try:
logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}")
self._collection.delete(ids=[entity_name])
except Exception as e:
logger.error(f"Error during entity deletion: {str(e)}")
raise
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete an entity and its relations by ID.
In vector DB context, this is equivalent to delete_entity.
Args:
entity_name: The ID of the entity to delete
"""
await self.delete_entity(entity_name)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Get all records from the collection
# Since ChromaDB doesn't directly support prefix search on IDs,
# we'll get all records and filter in Python
results = self._collection.get(
include=["metadatas", "documents", "embeddings"]
)
matching_records = []
# Filter records where ID starts with the prefix
for i, record_id in enumerate(results["ids"]):
if record_id.startswith(prefix):
matching_records.append(
{
"id": record_id,
"content": results["documents"][i],
"vector": results["embeddings"][i],
**results["metadatas"][i],
}
)
logger.debug(
f"Found {len(matching_records)} records with prefix '{prefix}'"
)
return matching_records
except Exception as e:
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
raise


@@ -371,3 +371,24 @@ class FaissVectorDBStorage(BaseVectorStorage):
return False  # Return error
return True  # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
matching_records = []
# Search for records with IDs starting with the prefix
for faiss_id, meta in self._id_to_meta.items():
if "__id__" in meta and meta["__id__"].startswith(prefix):
# Create a copy of all metadata and add "id" field
record = {**meta, "id": meta["__id__"]}
matching_records.append(record)
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records


@@ -16,7 +16,7 @@ from tenacity import (
wait_exponential,
)
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from ..base import BaseGraphStorage
@@ -396,17 +396,302 @@ class GremlinStorage(BaseGraphStorage):
print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified entity_name
Args:
node_id: The entity_name of the node to delete
"""
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.drop()
"""
try:
await self._query(query)
logger.debug(
"{%s}: Deleted node with entity_name '%s'",
inspect.currentframe().f_code.co_name,
entity_name,
)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Embed nodes using the specified algorithm.
Currently, only node2vec is supported but never called.
Args:
algorithm: The name of the embedding algorithm to use
Returns:
A tuple of (embeddings, node_ids)
Raises:
NotImplementedError: If the specified algorithm is not supported
ValueError: If the algorithm is not supported
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def get_all_labels(self) -> list[str]:
"""
Get all node entity_names in the graph
Returns:
[entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list
"""
query = f"""g
.V().has('graph', {self.graph_name})
.values('entity_name')
.dedup()
.order()
"""
try:
result = await self._query(query)
labels = result if result else []
logger.debug(
"{%s}: Retrieved %d labels",
inspect.currentframe().f_code.co_name,
len(labels),
)
return labels
except Exception as e:
logger.error(f"Error retrieving labels: {str(e)}")
return []
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
Args:
node_label: Entity name of the starting node
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
entity_name = GremlinStorage._fix_name(node_label)
# Handle special case for "*" label
if node_label == "*":
# For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES)
query = f"""g
.V().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.elementMap()
"""
nodes_result = await self._query(query)
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
# Get and add edges
if nodes_result:
query = f"""g
.V().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.outE()
.inV().has('graph', {self.graph_name})
.limit({MAX_GRAPH_NODES})
.path()
.by(elementMap())
.by(elementMap())
.by(elementMap())
"""
edges_result = await self._query(query)
for path in edges_result:
if len(path) >= 3: # source -> edge -> target
source = path[0]
edge_data = path[1]
target = path[2]
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties,
)
)
seen_edges.add(edge_id)
else:
# Search for specific node and get its neighborhood
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.repeat(__.both().simplePath().dedup())
.times({max_depth})
.emit()
.dedup()
.limit({MAX_GRAPH_NODES})
.elementMap()
"""
nodes_result = await self._query(query)
# Add nodes to result
for node_data in nodes_result:
node_id = node_data.get("entity_name", str(node_data.get("id", "")))
if str(node_id) in seen_nodes:
continue
# Create node with properties
node_properties = {
k: v for k, v in node_data.items() if k not in ["id", "label"]
}
result.nodes.append(
KnowledgeGraphNode(
id=str(node_id),
labels=[str(node_id)],
properties=node_properties,
)
)
seen_nodes.add(str(node_id))
# Get edges between the nodes in the result
if nodes_result:
node_ids = [
n.get("entity_name", str(n.get("id", ""))) for n in nodes_result
]
node_ids_query = ", ".join(
[GremlinStorage._to_value_map(nid) for nid in node_ids]
)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', within({node_ids_query}))
.outE()
.where(inV().has('graph', {self.graph_name})
.has('entity_name', within({node_ids_query})))
.path()
.by(elementMap())
.by(elementMap())
.by(elementMap())
"""
edges_result = await self._query(query)
for path in edges_result:
if len(path) >= 3: # source -> edge -> target
source = path[0]
edge_data = path[1]
target = path[2]
source_id = source.get("entity_name", str(source.get("id", "")))
target_id = target.get("entity_name", str(target.get("id", "")))
edge_id = f"{source_id}-{target_id}"
if edge_id in seen_edges:
continue
# Create edge with properties
edge_properties = {
k: v
for k, v in edge_data.items()
if k not in ["id", "label"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source_id),
target=str(target_id),
properties=edge_properties,
)
)
seen_edges.add(edge_id)
logger.info(
"Subgraph query successful | Node count: %d | Edge count: %d",
len(result.nodes),
len(result.edges),
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node entity_names to be deleted
"""
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
entity_name_source = GremlinStorage._fix_name(source)
entity_name_target = GremlinStorage._fix_name(target)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.where(inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target}))
.drop()
"""
try:
await self._query(query)
logger.debug(
"{%s}: Deleted edge from '%s' to '%s'",
inspect.currentframe().f_code.co_name,
entity_name_source,
entity_name_target,
)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise


@@ -3,7 +3,7 @@ import os
from typing import Any, final
from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger, compute_mdhash_id
from ..base import BaseVectorStorage
import pipmaster as pm
@@ -124,7 +124,110 @@ class MilvusVectorDBStorage(BaseVectorStorage):
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity from the vector database
Args:
entity_name: The name of the entity to delete
"""
try:
# Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity from Milvus collection
result = self._client.delete(
collection_name=self.namespace, pks=[entity_id]
)
if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Search for relations where entity is either source or target
expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
# Find all relations involving this entity
results = self._client.query(
collection_name=self.namespace, filter=expr, output_fields=["id"]
)
if not results or len(results) == 0:
logger.debug(f"No relations found for entity {entity_name}")
return
# Extract IDs of relations to delete
relation_ids = [item["id"] for item in results]
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
if relation_ids:
delete_result = self._client.delete(
collection_name=self.namespace, pks=relation_ids
)
logger.debug(
f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
try:
# Delete vectors by IDs
result = self._client.delete(collection_name=self.namespace, pks=ids)
if result and result.get("delete_count", 0) > 0:
logger.debug(
f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
)
else:
logger.debug(f"No vectors were deleted from {self.namespace}")
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use Milvus query with expression to find IDs with the given prefix
expression = f'id like "{prefix}%"'
results = self._client.query(
collection_name=self.namespace,
filter=expression,
output_fields=list(self.meta_fields) + ["id"],
)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
return results
except Exception as e:
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
return []


@@ -15,7 +15,7 @@ from ..base import (
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
@@ -333,7 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
Check if there's a direct single-hop edge from source_node_id to target_node_id.
We'll do a $graphLookup with maxDepth=0 from the source node—meaning
"Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
and then see if the target node is in the "reachableNodes" at depth=0.
But typically for a direct edge, we might just do a find_one.
@@ -795,6 +795,50 @@ class MongoGraphStorage(BaseGraphStorage):
# Mongo handles persistence automatically
pass
async def remove_nodes(self, nodes: list[str]) -> None:
"""Delete multiple nodes
Args:
nodes: List of node IDs to be deleted
"""
logger.info(f"Deleting {len(nodes)} nodes")
if not nodes:
return
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
await self.collection.update_many(
{}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
)
# 2. Delete the node documents
await self.collection.delete_many({"_id": {"$in": nodes}})
logger.debug(f"Successfully deleted nodes: {nodes}")
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
logger.info(f"Deleting {len(edges)} edges")
if not edges:
return
update_tasks = []
for source, target in edges:
# Remove edge pointing to target from source node's edges array
update_tasks.append(
self.collection.update_one(
{"_id": source}, {"$pull": {"edges": {"target": target}}}
)
)
if update_tasks:
await asyncio.gather(*update_tasks)
logger.debug(f"Successfully deleted edges: {edges}")
@final
@dataclass
@@ -932,11 +976,100 @@ class MongoVectorDBStorage(BaseVectorStorage):
# Mongo handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids:
return
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug(
f"Successfully deleted {result.deleted_count} vectors from {self.namespace}"
)
except PyMongoError as e:
logger.error(
f"Error while deleting vectors from {self.namespace}: {str(e)}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name
Args:
entity_name: Name of the entity to delete
"""
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
result = await self._data.delete_one({"_id": entity_id})
if result.deleted_count > 0:
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except PyMongoError as e:
logger.error(f"Error deleting entity {entity_name}: {str(e)}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
Args:
entity_name: Name of the entity whose relations should be deleted
"""
try:
# Find relations where entity appears as source or target
relations_cursor = self._data.find(
{"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]}
)
relations = await relations_cursor.to_list(length=None)
if not relations:
logger.debug(f"No relations found for entity {entity_name}")
return
# Extract IDs of relations to delete
relation_ids = [relation["_id"] for relation in relations]
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
result = await self._data.delete_many({"_id": {"$in": relation_ids}})
logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}")
except PyMongoError as e:
logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use MongoDB regex to find documents where _id starts with the prefix
cursor = self._data.find({"_id": {"$regex": f"^{prefix}"}})
matching_records = await cursor.to_list(length=None)
# Format results
results = [{**doc, "id": doc["_id"]} for doc in matching_records]
logger.debug(
f"Found {len(results)} records with prefix '{prefix}' in {self.namespace}"
)
return results
except PyMongoError as e:
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
return []
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):


@@ -236,3 +236,23 @@ class NanoVectorDBStorage(BaseVectorStorage):
return False  # Return error
return True  # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
storage = await self.client_storage
matching_records = []
# Search for records with IDs starting with the prefix
for record in storage["data"]:
if "__id__" in record and record["__id__"].startswith(prefix):
matching_records.append({**record, "id": record["__id__"]})
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records


@@ -232,19 +232,26 @@ class NetworkXStorage(BaseGraphStorage):
return sorted(list(labels))
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
Returns:
KnowledgeGraph object containing nodes and edges
@@ -255,6 +262,10 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
@@ -262,11 +273,16 @@ class NetworkXStorage(BaseGraphStorage):
graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id based on search_mode
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
@@ -277,26 +293,37 @@ class NetworkXStorage(BaseGraphStorage):
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
@@ -356,7 +383,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source),
target=str(target),
properties=edge_data,


@@ -8,7 +8,7 @@ from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..base import (
BaseGraphStorage,
@@ -442,11 +442,92 @@ class OracleVectorDBStorage(BaseVectorStorage):
# Oracles handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
try:
SQL = SQL_TEMPLATES["delete_vectors"].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
await self.db.execute(SQL, params)
logger.info(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise
async def delete_entity(self, entity_name: str) -> None:
"""Delete entity by name
Args:
entity_name: Name of the entity to delete
"""
try:
SQL = SQL_TEMPLATES["delete_entity"]
params = {"workspace": self.db.workspace, "entity_name": entity_name}
await self.db.execute(SQL, params)
logger.info(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
raise
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations connected to an entity
Args:
entity_name: Name of the entity whose relations should be deleted
"""
try:
SQL = SQL_TEMPLATES["delete_entity_relations"]
params = {"workspace": self.db.workspace, "entity_name": entity_name}
await self.db.execute(SQL, params)
logger.info(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Determine the appropriate table based on namespace
table_name = namespace_to_table_name(self.namespace)
# Create SQL query to find records with IDs starting with prefix
search_sql = f"""
SELECT * FROM {table_name}
WHERE workspace = :workspace
AND id LIKE :prefix_pattern
ORDER BY id
"""
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
# Execute query and get results
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results or []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@@ -668,15 +749,266 @@ class OracleGraphStorage(BaseGraphStorage):
return res
async def delete_node(self, node_id: str) -> None:
"""Delete a node from the graph
Args:
node_id: ID of the node to delete
"""
try:
# First delete all relations connected to this node
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
await self.db.execute(delete_relations_sql, params_relations)
# Then delete the node itself
delete_node_sql = SQL_TEMPLATES["delete_entity"]
params_node = {"workspace": self.db.workspace, "entity_name": node_id}
await self.db.execute(delete_node_sql, params_node)
logger.info(
f"Successfully deleted node {node_id} and all its relationships"
)
except Exception as e:
logger.error(f"Error deleting node {node_id}: {e}")
raise
async def remove_nodes(self, nodes: list[str]) -> None:
"""Delete multiple nodes from the graph
Args:
nodes: List of node IDs to be deleted
"""
if not nodes:
return
try:
for node in nodes:
# For each node, first delete all its relationships
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
params_relations = {"workspace": self.db.workspace, "entity_name": node}
await self.db.execute(delete_relations_sql, params_relations)
# Then delete the node itself
delete_node_sql = SQL_TEMPLATES["delete_entity"]
params_node = {"workspace": self.db.workspace, "entity_name": node}
await self.db.execute(delete_node_sql, params_node)
logger.info(
f"Successfully deleted {len(nodes)} nodes and their relationships"
)
except Exception as e:
logger.error(f"Error during batch node deletion: {e}")
raise
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""Delete multiple edges from the graph
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
if not edges:
return
try:
for source, target in edges:
# Check if the edge exists before attempting to delete
if await self.has_edge(source, target):
# Delete the edge using a SQL query that matches both source and target
delete_edge_sql = """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE workspace = :workspace
AND source_name = :source_name
AND target_name = :target_name
"""
params = {
"workspace": self.db.workspace,
"source_name": source,
"target_name": target,
}
await self.db.execute(delete_edge_sql, params)
logger.info(f"Successfully deleted {len(edges)} edges from the graph")
except Exception as e:
logger.error(f"Error during batch edge deletion: {e}")
raise
async def get_all_labels(self) -> list[str]:
"""Get all unique entity types (labels) in the graph
Returns:
List of unique entity types/labels
"""
try:
SQL = """
SELECT DISTINCT entity_type
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY entity_type
"""
params = {"workspace": self.db.workspace}
results = await self.db.query(SQL, params, multirows=True)
if results:
labels = [row["entity_type"] for row in results]
return labels
else:
return []
except Exception as e:
logger.error(f"Error retrieving entity types: {e}")
return []
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""Retrieve a connected subgraph starting from nodes matching the given label
Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
Prioritizes nodes by:
1. Nodes matching the specified label
2. Nodes directly connected to matching nodes
3. Node degree (number of connections)
Args:
node_label: Label to match for starting nodes (use "*" for all nodes)
max_depth: Maximum depth of traversal from starting nodes
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
try:
# Define maximum number of nodes to return
max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
if node_label == "*":
# For "*" label, get all nodes up to the limit
nodes_sql = """
SELECT name, entity_type, description, source_chunk_id
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY id
FETCH FIRST :limit ROWS ONLY
"""
nodes_params = {
"workspace": self.db.workspace,
"limit": max_graph_nodes,
}
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
else:
# For specific label, find matching nodes and related nodes
nodes_sql = """
WITH matching_nodes AS (
SELECT name
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
)
SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
CASE
WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
WHEN EXISTS (
SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
WHERE workspace = :workspace
AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
) THEN 1
ELSE 0
END AS priority,
(SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
WHERE workspace = :workspace
AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
FROM LIGHTRAG_GRAPH_NODES n
WHERE workspace = :workspace
ORDER BY priority DESC, degree DESC
FETCH FIRST :limit ROWS ONLY
"""
nodes_params = {
"workspace": self.db.workspace,
"node_label": node_label,
"limit": max_graph_nodes,
}
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
if not nodes:
logger.warning(f"No nodes found matching '{node_label}'")
return result
# Create mapping of node IDs to be used to filter edges
node_names = [node["name"] for node in nodes]
# Add nodes to result
seen_nodes = set()
for node in nodes:
node_id = node["name"]
if node_id in seen_nodes:
continue
# Create node properties dictionary
properties = {
"entity_type": node["entity_type"],
"description": node["description"] or "",
"source_id": node["source_chunk_id"] or "",
}
# Add node to result
result.nodes.append(
KnowledgeGraphNode(
id=node_id, labels=[node["entity_type"]], properties=properties
)
)
seen_nodes.add(node_id)
# Get edges between these nodes
edges_sql = """
SELECT source_name, target_name, weight, keywords, description, source_chunk_id
FROM LIGHTRAG_GRAPH_EDGES
WHERE workspace = :workspace
AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
ORDER BY id
"""
edges_params = {"workspace": self.db.workspace, "node_names": node_names}
edges = await self.db.query(edges_sql, edges_params, multirows=True)
# Add edges to result
seen_edges = set()
for edge in edges:
source = edge["source_name"]
target = edge["target_name"]
edge_id = f"{source}-{target}"
if edge_id in seen_edges:
continue
# Create edge properties dictionary
properties = {
"weight": edge["weight"] or 0.0,
"keywords": edge["keywords"] or "",
"description": edge["description"] or "",
"source_id": edge["source_chunk_id"] or "",
}
# Add edge to result
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
source=source,
target=target,
properties=properties,
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except Exception as e:
logger.error(f"Error retrieving knowledge graph: {e}")
return result
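A hedged call sketch for the subgraph query above; `graph_storage` stands in for an initialized OracleGraphStorage, and the node cap is taken from the MAX_GRAPH_NODES environment variable that the method reads internally.
```python
import asyncio
import os

async def preview_subgraph(graph_storage):
    # Cap the result size before querying; the method itself reads this variable.
    os.environ["MAX_GRAPH_NODES"] = "200"
    kg = await graph_storage.get_knowledge_graph(node_label="*", max_depth=3)
    print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")

# asyncio.run(preview_subgraph(graph_storage))  # graph_storage: assumed instance
```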
N_T = {
@@ -927,4 +1259,12 @@ SQL_TEMPLATES = {
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
)""",
# SQL for deletion
"delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
"delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
"delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
"delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
MATCH (a)
WHERE a.workspace=:workspace AND a.name=:node_id
ACTION DELETE a)""",
}

View File

@@ -7,7 +7,7 @@ from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import sys
from tenacity import (
@@ -512,11 +512,103 @@ class PGVectorStorage(BaseVectorStorage):
# PG handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = (
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
)
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name from the vector storage.
Args:
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity.
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for prefix search: {self.namespace}")
return []
search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}
try:
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
# Format results to match the expected return format
formatted_results = []
for record in results:
formatted_record = dict(record)
# Ensure id field is available (for consistency with NanoVectorDB implementation)
if "id" not in formatted_record:
formatted_record["id"] = record["id"]
formatted_results.append(formatted_record)
return formatted_results
except Exception as e:
logger.error(f"Error during prefix search for '{prefix}': {e}")
return []
@final
@@ -1086,20 +1178,188 @@ class PGGraphStorage(BaseGraphStorage):
print("Implemented but never called.") print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError """
Delete a node from the graph.
Args:
node_id (str): The ID of the node to delete.
"""
label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, label)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node deletion: {%s}", e)
raise
async def remove_nodes(self, node_ids: list[str]) -> None:
"""
Remove multiple nodes from the graph.
Args:
node_ids (list[str]): A list of node IDs to remove.
"""
encoded_node_ids = [
self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
WHERE n.node_id IN [%s]
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node removal: {%s}", e)
raise
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""
Remove multiple edges from the graph.
Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
encoded_edges = [
(
self._encode_graph_label(src.strip('"')),
self._encode_graph_label(tgt.strip('"')),
)
for src, tgt in edges
]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity)-[r]->(b:Entity)
WHERE [a.node_id, b.node_id] IN [%s]
DELETE r
$$) AS (r agtype)""" % (self.graph_name, edge_list)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during edge removal: {%s}", e)
raise
async def get_all_labels(self) -> list[str]:
"""
Get all labels (node IDs) in the graph.
Returns:
list[str]: A list of all labels in the graph.
"""
query = (
"""SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
$$) AS (label text)"""
% self.graph_name
)
results = await self._query(query)
labels = [self._decode_graph_label(result["label"]) for result in results]
return labels
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Generate node embeddings using the specified algorithm.
Args:
algorithm (str): The name of the embedding algorithm to use.
Returns:
tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
embed_func = self._node_embed_algorithms[algorithm]
return await embed_func()
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
Args:
node_label (str): The label of the node to start from. If "*", the entire graph is returned.
max_depth (int): The maximum depth to traverse from the starting node.
Returns:
KnowledgeGraph: The retrieved subgraph.
"""
MAX_GRAPH_NODES = 1000
if node_label == "*":
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query)
nodes = set()
edges = []
for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"]))
if result["m"]:
node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"]))
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
else:
if result["nodes"]:
for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"]))
if result["relationships"]:
for edge in result["relationships"]:
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
)
return kg
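For reference, a sketch of consuming the AGE-backed subgraph above; `pg_graph` is assumed to be an initialized PGGraphStorage instance.
```python
import asyncio

async def dump_subgraph(pg_graph):
    kg = await pg_graph.get_knowledge_graph(node_label="*", max_depth=2)
    # Nodes carry decoded node_id values; edges are simple source/target pairs.
    for node in kg.nodes:
        print(node.id)
    for edge in kg.edges:
        print(edge.source, "->", edge.target)

# asyncio.run(dump_subgraph(pg_graph))  # pg_graph: assumed instance
```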
async def drop(self) -> None:
"""Drop the storage"""

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from typing import Any, final, List
from dataclasses import dataclass
import numpy as np
import hashlib
@@ -141,8 +141,137 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Qdrant handles persistence automatically
pass
async def delete(self, ids: List[str]) -> None:
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
try:
# Convert regular ids to Qdrant compatible ids
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Delete points from the collection
self._client.delete(
collection_name=self.namespace,
points_selector=models.PointIdsList(
points=qdrant_ids,
),
wait=True,
)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name
Args:
entity_name: Name of the entity to delete
"""
try:
# Generate the entity ID
entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity point from the collection
self._client.delete(
collection_name=self.namespace,
points_selector=models.PointIdsList(
points=[entity_id],
),
wait=True,
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
Args:
entity_name: Name of the entity whose relations should be deleted
"""
try:
# Find relations where the entity is either source or target
results = self._client.scroll(
collection_name=self.namespace,
scroll_filter=models.Filter(
should=[
models.FieldCondition(
key="src_id", match=models.MatchValue(value=entity_name)
),
models.FieldCondition(
key="tgt_id", match=models.MatchValue(value=entity_name)
),
]
),
with_payload=True,
limit=1000, # Adjust as needed for your use case
)
# Extract points that need to be deleted
relation_points = results[0]
ids_to_delete = [point.id for point in relation_points]
if ids_to_delete:
# Delete the relations
self._client.delete(
collection_name=self.namespace,
points_selector=models.PointIdsList(
points=ids_to_delete,
),
wait=True,
)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use scroll method to find records with IDs starting with the prefix
results = self._client.scroll(
collection_name=self.namespace,
scroll_filter=models.Filter(
must=[
models.FieldCondition(
key="id", match=models.MatchText(text=prefix, prefix=True)
)
]
),
with_payload=True,
with_vectors=False,
limit=1000, # Adjust as needed for your use case
)
# Extract matching points
matching_records = results[0]
# Format the results to match expected return format
formatted_results = [
{**point.payload, "id": point.id} for point in matching_records
]
logger.debug(
f"Found {len(formatted_results)} records with prefix '{prefix}'"
)
return formatted_results
except Exception as e:
logger.error(f"Error searching for prefix '{prefix}': {e}")
return []
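A hedged usage sketch for the Qdrant helpers above; `qdrant_storage` is assumed to be an initialized QdrantVectorDBStorage, and the point IDs are recomputed internally from entity names.
```python
import asyncio

async def remove_entity(qdrant_storage, entity_name):
    # delete_entity derives the Qdrant point ID from the entity name itself,
    # so callers only pass the plain name; relations are cleaned up separately.
    await qdrant_storage.delete_entity(entity_name)
    await qdrant_storage.delete_entity_relation(entity_name)

# asyncio.run(remove_entity(qdrant_storage, "Apple Inc."))  # qdrant_storage: assumed instance
```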

View File

@@ -9,7 +9,7 @@ if not pm.is_installed("redis"):
# aioredis is a depricated library, replaced with redis
from redis.asyncio import Redis
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseKVStorage
import json
@@ -64,3 +64,86 @@ class RedisKVStorage(BaseKVStorage):
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete entries with specified IDs
Args:
ids: List of entry IDs to be deleted
"""
if not ids:
return
pipe = self._redis.pipeline()
for id in ids:
pipe.delete(f"{self.namespace}:{id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name
Args:
entity_name: Name of the entity to delete
"""
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity
result = await self._redis.delete(f"{self.namespace}:{entity_id}")
if result:
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
Args:
entity_name: Name of the entity whose relations should be deleted
"""
try:
# Get all keys in this namespace
cursor = 0
relation_keys = []
pattern = f"{self.namespace}:*"
while True:
cursor, keys = await self._redis.scan(cursor, match=pattern)
# For each key, get the value and check if it's related to entity_name
for key in keys:
value = await self._redis.get(key)
if value:
data = json.loads(value)
# Check if this is a relation involving the entity
if (
data.get("src_id") == entity_name
or data.get("tgt_id") == entity_name
):
relation_keys.append(key)
# Exit loop when cursor returns to 0
if cursor == 0:
break
# Delete the relation keys
if relation_keys:
deleted = await self._redis.delete(*relation_keys)
logger.debug(f"Deleted {deleted} relations for {entity_name}")
else:
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")

View File

@@ -5,7 +5,7 @@ from typing import Any, Union, final
import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
@@ -414,6 +414,55 @@ class TiDBVectorDBStorage(BaseVectorStorage):
# Ti handles persistence automatically
pass
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
"""
else:
logger.warning(
f"Namespace {self.namespace} not supported for prefix search"
)
return []
# Add prefix pattern parameter with % for SQL LIKE
prefix_pattern = f"{prefix}%"
params = {"prefix_pattern": prefix_pattern, "workspace": self.db.workspace}
try:
results = await self.db.query(sql_template, params=params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results if results else []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@dataclass
@@ -566,15 +615,163 @@ class TiDBGraphStorage(BaseGraphStorage):
pass
async def delete_node(self, node_id: str) -> None:
"""Delete a node and all its related edges
Args:
node_id: The ID of the node to delete
"""
# First delete all edges related to this node
await self.db.execute(
SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace},
)
# Then delete the node itself
await self.db.execute(
SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace},
)
logger.debug(
f"Node {node_id} and its related edges have been deleted from the graph"
)
async def get_all_labels(self) -> list[str]:
"""Get all entity types (labels) in the database
Returns:
List of labels sorted alphabetically
"""
result = await self.db.query(
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True,
)
if not result:
return []
# Extract all labels
return [item["label"] for item in result]
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get a connected subgraph of nodes matching the specified label
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
Args:
node_label: The node label to match
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Get matching nodes
if node_label == "*":
# Handle special case, get all nodes
node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True,
)
else:
# Get nodes matching the label
label_pattern = f"%{node_label}%"
node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True,
)
if not node_results:
logger.warning(f"No nodes found matching label {node_label}")
return result
# Limit the number of returned nodes
if len(node_results) > MAX_GRAPH_NODES:
node_results = node_results[:MAX_GRAPH_NODES]
# Extract node names for edge query
node_names = [node["name"] for node in node_results]
node_names_str = ",".join([f"'{name}'" for name in node_names])
# Add nodes to result
for node in node_results:
node_properties = {
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
}
result.nodes.append(
KnowledgeGraphNode(
id=node["name"],
labels=[node["entity_type"]]
if node.get("entity_type")
else [node["name"]],
properties=node_properties,
)
)
# Get related edges
edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace},
multirows=True,
)
if edge_results:
# Add edges to result
for edge in edge_results:
# Only include edges related to selected nodes
if (
edge["source_name"] in node_names
and edge["target_name"] in node_names
):
edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {
k: v
for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
source=edge["source_name"],
target=edge["target_name"],
properties=edge_properties,
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to delete
"""
for node_id in nodes:
await self.delete_node(node_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to delete, each edge is a (source, target) tuple
"""
for source, target in edges:
await self.db.execute(
SQL_TEMPLATES["remove_multiple_edges"],
{"source": source, "target": target, "workspace": self.db.workspace},
)
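A brief driver sketch for the TiDB deletion helpers above; `tidb_graph` stands in for a configured TiDBGraphStorage.
```python
import asyncio

async def cleanup(tidb_graph):
    # Edge removal issues one DELETE per (source, target) pair; node removal
    # drops each node's edges first and then the node row itself.
    await tidb_graph.remove_edges([("Apple", "iPhone")])
    await tidb_graph.remove_nodes(["Obsolete Entity"])

# asyncio.run(cleanup(tidb_graph))  # tidb_graph: assumed instance
```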
N_T = {
@@ -785,4 +982,55 @@ SQL_TEMPLATES = {
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
source_chunk_id = VALUES(source_chunk_id)
""",
"delete_node": """
DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE name = :name AND workspace = :workspace
""",
"delete_node_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
""",
"get_all_labels": """
SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY entity_type
""",
"get_matching_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE name LIKE :label_pattern AND workspace = :workspace
ORDER BY name
""",
"get_all_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY name
LIMIT :max_nodes
""",
"get_related_edges": """
SELECT * FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name IN (:node_names) OR target_name IN (:node_names))
AND workspace = :workspace
""",
"remove_multiple_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
""",
# Search by prefix SQL templates
"search_entity_by_prefix": """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_relationship_by_prefix": """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id, keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_chunk_by_prefix": """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
""",
}
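As a reading aid, the prefix-search templates above are plain named-parameter SQL; a hypothetical call site would bind them roughly like this (the workspace value is illustrative):
```python
# Hypothetical binding for the chunk prefix-search template shown above.
sql = """
    SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
    FROM LIGHTRAG_DOC_CHUNKS
    WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
"""
params = {"prefix_pattern": "chunk-%", "workspace": "default"}
# results = await db.query(sql, params=params, multirows=True)  # db: assumed TiDB client wrapper
```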

View File

@@ -504,11 +504,39 @@ class LightRAG:
return text
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
) -> KnowledgeGraph:
"""Get knowledge graph for a given label
Args:
node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges
"""
# get params supported by get_knowledge_graph of specified storage
import inspect
storage_params = inspect.signature(
self.chunk_entity_relation_graph.get_knowledge_graph
).parameters
kwargs = {"node_label": node_label, "max_depth": max_depth}
if "min_degree" in storage_params and min_degree > 0:
kwargs["min_degree"] = min_degree
if "inclusive" in storage_params:
kwargs["inclusive"] = inclusive
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
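A hedged call sketch for the updated signature; extra arguments are only forwarded when the active graph storage accepts them, which is what the `inspect.signature` check above implements. `rag` is assumed to be an initialized LightRAG instance.
```python
import asyncio

async def show_graph(rag):
    # min_degree and inclusive are dropped silently for backends that do not support them.
    kg = await rag.get_knowledge_graph("Apple", max_depth=2, min_degree=1, inclusive=True)
    print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")

# asyncio.run(show_graph(rag))  # rag: assumed initialized LightRAG instance
```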
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
@@ -845,7 +873,8 @@ class LightRAG:
)
}
# Process document (text chunks and full docs) in parallel
# Create tasks with references for potential cancellation
doc_status_task = asyncio.create_task(
self.doc_status.upsert(
{
doc_id: {
@@ -857,13 +886,28 @@ class LightRAG:
"created_at": status_doc.created_at, "created_at": status_doc.created_at,
} }
} }
), )
self.chunks_vdb.upsert(chunks), )
self._process_entity_relation_graph(chunks), chunks_vdb_task = asyncio.create_task(
self.chunks_vdb.upsert(chunks)
)
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(chunks)
)
full_docs_task = asyncio.create_task(
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
)
)
text_chunks_task = asyncio.create_task(
self.text_chunks.upsert(chunks)
)
tasks = [
doc_status_task,
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]
try:
await asyncio.gather(*tasks)
@@ -881,9 +925,25 @@ class LightRAG:
}
)
except Exception as e:
# Log error and update pipeline status
error_msg = (
f"Failed to process document {doc_id}: {str(e)}"
)
logger.error(error_msg)
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(error_msg)
# Cancel other tasks as they are no longer meaningful
for task in [
chunks_vdb_task,
entity_relation_task,
full_docs_task,
text_chunks_task,
]:
if not task.done():
task.cancel()
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
@@ -926,7 +986,7 @@ class LightRAG:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
# 获取新的待处理文档 # Check for pending documents again
processing_docs, failed_docs, pending_docs = await asyncio.gather( processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING), self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED), self.doc_status.get_docs_by_status(DocStatus.FAILED),
@@ -1403,6 +1463,68 @@ class LightRAG:
]
)
def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Synchronously delete a relation between two entities.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.adelete_by_relation(source_entity, target_entity)
)
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Asynchronously delete a relation between two entities.
Args:
source_entity: Name of the source entity
target_entity: Name of the target entity
"""
try:
# Check if the relation exists
edge_exists = await self.chunk_entity_relation_graph.has_edge(
source_entity, target_entity
)
if not edge_exists:
logger.warning(
f"Relation from '{source_entity}' to '{target_entity}' does not exist"
)
return
# Delete relation from vector database
relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
await self.relationships_vdb.delete([relation_id])
# Delete relation from knowledge graph
await self.chunk_entity_relation_graph.remove_edges(
[(source_entity, target_entity)]
)
logger.info(
f"Successfully deleted relation from '{source_entity}' to '{target_entity}'"
)
await self._delete_relation_done()
except Exception as e:
logger.error(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
)
async def _delete_relation_done(self) -> None:
"""Callback after relation deletion is complete"""
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)
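A short usage sketch for the relation deletion API added above, assuming `rag` is an initialized LightRAG instance; the synchronous wrapper simply drives the async variant on the current event loop.
```python
def remove_relation(rag):
    # rag: assumed initialized LightRAG instance.
    # Removes the directed relation from both the vector database and the knowledge graph.
    rag.delete_by_relation("Microsoft", "Windows")

# Inside async code, prefer: await rag.adelete_by_relation("Microsoft", "Windows")
```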
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
@@ -1497,51 +1619,57 @@ class LightRAG:
await self.text_chunks.delete(chunk_ids)
# 5. Find and process entities and relationships that have these chunks as source
# Get all nodes and edges from the graph storage using storage-agnostic methods
entities_to_delete = set()
entities_to_update = {}  # entity_name -> new_source_id
relationships_to_delete = set()
relationships_to_update = {}  # (src, tgt) -> new_source_id
# Process entities - use storage-agnostic methods
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
for node_label in all_labels:
node_data = await self.chunk_entity_relation_graph.get_node(node_label)
if node_data and "source_id" in node_data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
entities_to_delete.add(node_label)
logger.debug(
f"Entity {node_label} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
entities_to_update[node_label] = new_source_id
logger.debug(
f"Entity {node_label} will be updated with new source_id: {new_source_id}"
)
# Process relationships
for node_label in all_labels:
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
node_label
)
if node_edges:
for src, tgt in node_edges:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
if edge_data and "source_id" in edge_data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
sources.difference_update(chunk_ids)
if not sources:
relationships_to_delete.add((src, tgt))
logger.debug(
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
)
else:
new_source_id = GRAPH_FIELD_SEP.join(sources)
relationships_to_update[(src, tgt)] = new_source_id
logger.debug(
f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
)
# Delete entities
if entities_to_delete:
@@ -1555,12 +1683,15 @@ class LightRAG:
# Update entities
for entity, new_source_id in entities_to_update.items():
node_data = await self.chunk_entity_relation_graph.get_node(entity)
if node_data:
node_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_node(
entity, node_data
)
logger.debug(
f"Updated entity {entity} with new source_id: {new_source_id}"
)
# Delete relationships
if relationships_to_delete:
@@ -1578,12 +1709,15 @@ class LightRAG:
# Update relationships
for (src, tgt), new_source_id in relationships_to_update.items():
edge_data = await self.chunk_entity_relation_graph.get_edge(src, tgt)
if edge_data:
edge_data["source_id"] = new_source_id
await self.chunk_entity_relation_graph.upsert_edge(
src, tgt, edge_data
)
logger.debug(
f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
)
# 6. Delete original document and status
await self.full_docs.delete([doc_id])
@@ -1875,6 +2009,9 @@ class LightRAG:
new_entity_name, new_node_data
)
# Store relationships that need to be updated
relations_to_update = []
# Get all edges related to the original entity
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
@@ -1890,10 +2027,16 @@ class LightRAG:
await self.chunk_entity_relation_graph.upsert_edge(
new_entity_name, target, edge_data
)
relations_to_update.append(
(new_entity_name, target, edge_data)
)
else:  # target == entity_name
await self.chunk_entity_relation_graph.upsert_edge(
source, new_entity_name, edge_data
)
relations_to_update.append(
(source, new_entity_name, edge_data)
)
# Delete old entity
await self.chunk_entity_relation_graph.delete_node(entity_name)
@@ -1901,6 +2044,38 @@ class LightRAG:
# Delete old entity record from vector database
old_entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([old_entity_id])
logger.info(
f"Deleted old entity '{entity_name}' and its vector embedding from database"
)
# Update relationship vector representations
for src, tgt, edge_data in relations_to_update:
description = edge_data.get("description", "")
keywords = edge_data.get("keywords", "")
source_id = edge_data.get("source_id", "")
weight = float(edge_data.get("weight", 1.0))
# Create new content for embedding
content = f"{src}\t{tgt}\n{keywords}\n{description}"
# Calculate relationship ID
relation_id = compute_mdhash_id(src + tgt, prefix="rel-")
# Prepare data for vector database update
relation_data = {
relation_id: {
"content": content,
"src_id": src,
"tgt_id": tgt,
"source_id": source_id,
"description": description,
"keywords": keywords,
"weight": weight,
}
}
# Update vector database
await self.relationships_vdb.upsert(relation_data)
# Update working entity name to new name
entity_name = new_entity_name
@@ -1999,6 +2174,15 @@ class LightRAG:
f"Relation from '{source_entity}' to '{target_entity}' does not exist" f"Relation from '{source_entity}' to '{target_entity}' does not exist"
) )
# Important: First delete the old relation record from the vector database
old_relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
await self.relationships_vdb.delete([old_relation_id])
logger.info(
f"Deleted old relation record from vector database for relation {source_entity} -> {target_entity}"
)
# 2. Update relation information in the graph
new_edge_data = {**edge_data, **updated_data}
await self.chunk_entity_relation_graph.upsert_edge(
@@ -2012,7 +2196,7 @@ class LightRAG:
weight = float(new_edge_data.get("weight", 1.0))
# Create content for embedding
content = f"{source_entity}\t{target_entity}\n{keywords}\n{description}"
# Calculate relation ID
relation_id = compute_mdhash_id(
@@ -2276,3 +2460,409 @@ class LightRAG:
return loop.run_until_complete(
self.acreate_relation(source_entity, target_entity, relation_data)
)
async def amerge_entities(
self,
source_entities: list[str],
target_entity: str,
merge_strategy: dict[str, str] = None,
target_entity_data: dict[str, Any] = None,
) -> dict[str, Any]:
"""Asynchronously merge multiple entities into one entity.
Merges multiple source entities into a target entity, handling all relationships,
and updating both the knowledge graph and vector database.
Args:
source_entities: List of source entity names to merge
target_entity: Name of the target entity after merging
merge_strategy: Merge strategy configuration, e.g. {"description": "concatenate", "entity_type": "keep_first"}
Supported strategies:
- "concatenate": Concatenate all values (for text fields)
- "keep_first": Keep the first non-empty value
- "keep_last": Keep the last non-empty value
- "join_unique": Join all unique values (for fields separated by delimiter)
target_entity_data: Dictionary of specific values to set for the target entity,
overriding any merged values, e.g. {"description": "custom description", "entity_type": "PERSON"}
Returns:
Dictionary containing the merged entity information
"""
try:
# Default merge strategy
default_strategy = {
"description": "concatenate",
"entity_type": "keep_first",
"source_id": "join_unique",
}
merge_strategy = (
default_strategy
if merge_strategy is None
else {**default_strategy, **merge_strategy}
)
target_entity_data = (
{} if target_entity_data is None else target_entity_data
)
# 1. Check if all source entities exist
source_entities_data = {}
for entity_name in source_entities:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
if not node_data:
raise ValueError(f"Source entity '{entity_name}' does not exist")
source_entities_data[entity_name] = node_data
# 2. Check if target entity exists and get its data if it does
target_exists = await self.chunk_entity_relation_graph.has_node(
target_entity
)
target_entity_data = {}
if target_exists:
target_entity_data = await self.chunk_entity_relation_graph.get_node(
target_entity
)
logger.info(
f"Target entity '{target_entity}' already exists, will merge data"
)
# 3. Merge entity data
merged_entity_data = self._merge_entity_attributes(
list(source_entities_data.values())
+ ([target_entity_data] if target_exists else []),
merge_strategy,
)
# Apply any explicitly provided target entity data (overrides merged data)
for key, value in target_entity_data.items():
merged_entity_data[key] = value
# 4. Get all relationships of the source entities
all_relations = []
for entity_name in source_entities:
# Get all relationships where this entity is the source
outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
)
if outgoing_edges:
for src, tgt in outgoing_edges:
# Ensure src is the current entity
if src == entity_name:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("outgoing", src, tgt, edge_data))
# Get all relationships where this entity is the target
incoming_edges = []
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
for label in all_labels:
if label == entity_name:
continue
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
label
)
for src, tgt in node_edges or []:
if tgt == entity_name:
incoming_edges.append((src, tgt))
for src, tgt in incoming_edges:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("incoming", src, tgt, edge_data))
# 5. Create or update the target entity
if not target_exists:
await self.chunk_entity_relation_graph.upsert_node(
target_entity, merged_entity_data
)
logger.info(f"Created new target entity '{target_entity}'")
else:
await self.chunk_entity_relation_graph.upsert_node(
target_entity, merged_entity_data
)
logger.info(f"Updated existing target entity '{target_entity}'")
# 6. Recreate all relationships, pointing to the target entity
relation_updates = {} # Track relationships that need to be merged
for rel_type, src, tgt, edge_data in all_relations:
new_src = target_entity if src in source_entities else src
new_tgt = target_entity if tgt in source_entities else tgt
# Skip relationships between source entities to avoid self-loops
if new_src == new_tgt:
logger.info(
f"Skipping relationship between source entities: {src} -> {tgt} to avoid self-loop"
)
continue
# Check if the same relationship already exists
relation_key = f"{new_src}|{new_tgt}"
if relation_key in relation_updates:
# Merge relationship data
existing_data = relation_updates[relation_key]["data"]
merged_relation = self._merge_relation_attributes(
[existing_data, edge_data],
{
"description": "concatenate",
"keywords": "join_unique",
"source_id": "join_unique",
"weight": "max",
},
)
relation_updates[relation_key]["data"] = merged_relation
logger.info(
f"Merged duplicate relationship: {new_src} -> {new_tgt}"
)
else:
relation_updates[relation_key] = {
"src": new_src,
"tgt": new_tgt,
"data": edge_data.copy(),
}
# Apply relationship updates
for rel_data in relation_updates.values():
await self.chunk_entity_relation_graph.upsert_edge(
rel_data["src"], rel_data["tgt"], rel_data["data"]
)
logger.info(
f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
)
# 7. Update entity vector representation
description = merged_entity_data.get("description", "")
source_id = merged_entity_data.get("source_id", "")
entity_type = merged_entity_data.get("entity_type", "")
content = target_entity + "\n" + description
entity_id = compute_mdhash_id(target_entity, prefix="ent-")
entity_data_for_vdb = {
entity_id: {
"content": content,
"entity_name": target_entity,
"source_id": source_id,
"description": description,
"entity_type": entity_type,
}
}
await self.entities_vdb.upsert(entity_data_for_vdb)
# 8. Update relationship vector representations
for rel_data in relation_updates.values():
src = rel_data["src"]
tgt = rel_data["tgt"]
edge_data = rel_data["data"]
description = edge_data.get("description", "")
keywords = edge_data.get("keywords", "")
source_id = edge_data.get("source_id", "")
weight = float(edge_data.get("weight", 1.0))
content = f"{keywords}\t{src}\n{tgt}\n{description}"
relation_id = compute_mdhash_id(src + tgt, prefix="rel-")
relation_data_for_vdb = {
relation_id: {
"content": content,
"src_id": src,
"tgt_id": tgt,
"source_id": source_id,
"description": description,
"keywords": keywords,
"weight": weight,
}
}
await self.relationships_vdb.upsert(relation_data_for_vdb)
# 9. Delete source entities
for entity_name in source_entities:
# Delete entity node from knowledge graph
await self.chunk_entity_relation_graph.delete_node(entity_name)
# Delete entity record from vector database
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([entity_id])
# Also ensure any relationships specific to this entity are deleted from vector DB
# This is a safety check, as these should have been transformed to the target entity already
entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
relations_with_entity = await self.relationships_vdb.search_by_prefix(
entity_relation_prefix
)
if relations_with_entity:
relation_ids = [r["id"] for r in relations_with_entity]
await self.relationships_vdb.delete(relation_ids)
logger.info(
f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
)
logger.info(
f"Deleted source entity '{entity_name}' and its vector embedding from database"
)
# 10. Save changes
await self._merge_entities_done()
logger.info(
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
)
return await self.get_entity_info(target_entity, include_vector_data=True)
except Exception as e:
logger.error(f"Error merging entities: {e}")
raise
def merge_entities(
self,
source_entities: list[str],
target_entity: str,
merge_strategy: dict[str, str] = None,
target_entity_data: dict[str, Any] = None,
) -> dict[str, Any]:
"""Synchronously merge multiple entities into one entity.
Merges multiple source entities into a target entity, handling all relationships,
and updating both the knowledge graph and vector database.
Args:
source_entities: List of source entity names to merge
target_entity: Name of the target entity after merging
merge_strategy: Merge strategy configuration, e.g. {"description": "concatenate", "entity_type": "keep_first"}
target_entity_data: Dictionary of specific values to set for the target entity,
overriding any merged values, e.g. {"description": "custom description", "entity_type": "PERSON"}
Returns:
Dictionary containing the merged entity information
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.amerge_entities(
source_entities, target_entity, merge_strategy, target_entity_data
)
)
def _merge_entity_attributes(
self, entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
"""Merge attributes from multiple entities.
Args:
entity_data_list: List of dictionaries containing entity data
merge_strategy: Merge strategy for each field
Returns:
Dictionary containing merged entity data
"""
merged_data = {}
# Collect all possible keys
all_keys = set()
for data in entity_data_list:
all_keys.update(data.keys())
# Merge values for each key
for key in all_keys:
# Get all values for this key
values = [data.get(key) for data in entity_data_list if data.get(key)]
if not values:
continue
# Merge values according to strategy
strategy = merge_strategy.get(key, "keep_first")
if strategy == "concatenate":
merged_data[key] = "\n\n".join(values)
elif strategy == "keep_first":
merged_data[key] = values[0]
elif strategy == "keep_last":
merged_data[key] = values[-1]
elif strategy == "join_unique":
# Handle fields separated by GRAPH_FIELD_SEP
unique_items = set()
for value in values:
items = value.split(GRAPH_FIELD_SEP)
unique_items.update(items)
merged_data[key] = GRAPH_FIELD_SEP.join(unique_items)
else:
# Default strategy
merged_data[key] = values[0]
return merged_data
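To make the strategies concrete, a toy illustration (not LightRAG code) of what the helper above produces for two records, using "<SEP>" as a stand-in for GRAPH_FIELD_SEP:
```python
# Two source-entity records to be merged.
a = {"description": "US technology company", "entity_type": "ORG", "source_id": "chunk-1"}
b = {"description": "Maker of Windows", "entity_type": "COMPANY", "source_id": "chunk-2"}

merged = {
    "description": "\n\n".join([a["description"], b["description"]]),  # "concatenate"
    "entity_type": a["entity_type"],                                   # "keep_first"
    "source_id": "<SEP>".join({"chunk-1", "chunk-2"}),                 # "join_unique" (set order is arbitrary)
}
print(merged)
```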
def _merge_relation_attributes(
self, relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
"""Merge attributes from multiple relationships.
Args:
relation_data_list: List of dictionaries containing relationship data
merge_strategy: Merge strategy for each field
Returns:
Dictionary containing merged relationship data
"""
merged_data = {}
# Collect all possible keys
all_keys = set()
for data in relation_data_list:
all_keys.update(data.keys())
# Merge values for each key
for key in all_keys:
# Get all values for this key
values = [
data.get(key)
for data in relation_data_list
if data.get(key) is not None
]
if not values:
continue
# Merge values according to strategy
strategy = merge_strategy.get(key, "keep_first")
if strategy == "concatenate":
merged_data[key] = "\n\n".join(str(v) for v in values)
elif strategy == "keep_first":
merged_data[key] = values[0]
elif strategy == "keep_last":
merged_data[key] = values[-1]
elif strategy == "join_unique":
# Handle fields separated by GRAPH_FIELD_SEP
unique_items = set()
for value in values:
items = str(value).split(GRAPH_FIELD_SEP)
unique_items.update(items)
merged_data[key] = GRAPH_FIELD_SEP.join(unique_items)
elif strategy == "max":
# For numeric fields like weight
try:
merged_data[key] = max(float(v) for v in values)
except (ValueError, TypeError):
merged_data[key] = values[0]
else:
# Default strategy
merged_data[key] = values[0]
return merged_data
async def _merge_entities_done(self) -> None:
"""Callback after entity merging is complete, ensures updates are persisted"""
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)

View File

@@ -1242,9 +1242,11 @@ async def _find_most_related_text_unit_from_entities(
all_text_units_lookup = {}
tasks = []
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
for c_id in this_text_units:
if c_id not in all_text_units_lookup:
all_text_units_lookup[c_id] = index
tasks.append((c_id, index, this_edges))
results = await asyncio.gather(

View File

@@ -161,8 +161,12 @@ axiosInstance.interceptors.response.use(
 )

 // API methods
-export const queryGraphs = async (label: string, maxDepth: number): Promise<LightragGraphType> => {
-  const response = await axiosInstance.get(`/graphs?label=${label}&max_depth=${maxDepth}`)
+export const queryGraphs = async (
+  label: string,
+  maxDepth: number,
+  minDegree: number
+): Promise<LightragGraphType> => {
+  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
   return response.data
 }
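
Outside the WebUI, the same `/graphs` endpoint can be exercised directly. Below is a minimal sketch using `requests`, assuming the API server is reachable on its default port 9621, no authentication is enforced, and the response carries `nodes`/`edges` arrays as the WebUI expects; adjust host, port, and auth to match your deployment.

```python
# Sketch: fetch a graph slice from the LightRAG API server with the new min_degree filter.
import requests

params = {"label": "Artificial Intelligence", "max_depth": 3, "min_degree": 2}
resp = requests.get("http://localhost:9621/graphs", params=params, timeout=30)
resp.raise_for_status()

graph = resp.json()
print(len(graph.get("nodes", [])), "nodes returned")
```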

View File

@@ -40,18 +40,21 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
   const focusedEdge = useGraphStore.use.focusedEdge()

   /**
-   * When component mount
-   * => load the graph
+   * When component mount or maxIterations changes
+   * => load the graph and apply layout
    */
   useEffect(() => {
     // Create & load the graph
     const graph = lightrageGraph()
     loadGraph(graph)
-    if (!(graph as any).__force_applied) {
-      assignLayout()
-      Object.assign(graph, { __force_applied: true })
-    }
+    assignLayout()
+  }, [assignLayout, loadGraph, lightrageGraph, maxIterations])
+
+  /**
+   * When component mount
+   * => register events
+   */
+  useEffect(() => {
     const { setFocusedNode, setSelectedNode, setFocusedEdge, setSelectedEdge, clearSelection } =
       useGraphStore.getState()
@@ -87,7 +90,7 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
       },
       clickStage: () => clearSelection()
     })
-  }, [assignLayout, loadGraph, registerEvents, lightrageGraph])
+  }, [registerEvents])

   /**
    * When component mount or hovered node change

View File

@@ -91,9 +91,12 @@ const LabeledNumberInput = ({
         {label}
       </label>
       <Input
-        value={currentValue || ''}
+        type="number"
+        value={currentValue === null ? '' : currentValue}
         onChange={onValueChange}
-        className="h-6 w-full min-w-0"
+        className="h-6 w-full min-w-0 pr-1"
+        min={min}
+        max={max}
         onBlur={onBlur}
         onKeyDown={(e) => {
           if (e.key === 'Enter') {
@@ -120,6 +123,7 @@ export default function Settings() {
   const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
   const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
   const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
+  const graphMinDegree = useSettingsStore.use.graphMinDegree()
   const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
   const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
@@ -178,6 +182,11 @@ export default function Settings() {
     useSettingsStore.setState({ graphQueryMaxDepth: depth })
   }, [])

+  const setGraphMinDegree = useCallback((degree: number) => {
+    if (degree < 0) return
+    useSettingsStore.setState({ graphMinDegree: degree })
+  }, [])
+
   const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
     if (iterations < 1) return
     useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
@@ -274,7 +283,7 @@ export default function Settings() {
             min={0}
             value={graphMinDegree}
             onEditFinished={setGraphMinDegree}
           />
           <LabeledNumberInput
             label={t("graphPanel.sideBar.settings.maxLayoutIterations")}
             min={1}
@@ -282,7 +291,6 @@ export default function Settings() {
             value={graphLayoutMaxIterations}
             onEditFinished={setGraphLayoutMaxIterations}
           />
           <Separator />
           <div className="flex flex-col gap-2">

View File

@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
     <input
       type={type}
       className={cn(
-        'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm',
+        'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
         className
       )}
       ref={ref}

View File

@@ -50,11 +50,11 @@ export type NodeType = {
 }

 export type EdgeType = { label: string }

-const fetchGraph = async (label: string, maxDepth: number) => {
+const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
   let rawData: any = null
   try {
-    rawData = await queryGraphs(label, maxDepth)
+    rawData = await queryGraphs(label, maxDepth, minDegree)
   } catch (e) {
     useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
     return null
@@ -161,13 +161,14 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
   return graph
 }

-const lastQueryLabel = { label: '', maxQueryDepth: 0 }
+const lastQueryLabel = { label: '', maxQueryDepth: 0, minDegree: 0 }

 const useLightrangeGraph = () => {
   const queryLabel = useSettingsStore.use.queryLabel()
   const rawGraph = useGraphStore.use.rawGraph()
   const sigmaGraph = useGraphStore.use.sigmaGraph()
   const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
+  const minDegree = useSettingsStore.use.graphMinDegree()

   const getNode = useCallback(
     (nodeId: string) => {
@@ -185,13 +186,16 @@ const useLightrangeGraph = () => {
   useEffect(() => {
     if (queryLabel) {
-      if (lastQueryLabel.label !== queryLabel || lastQueryLabel.maxQueryDepth !== maxQueryDepth) {
+      if (lastQueryLabel.label !== queryLabel ||
+          lastQueryLabel.maxQueryDepth !== maxQueryDepth ||
+          lastQueryLabel.minDegree !== minDegree) {
         lastQueryLabel.label = queryLabel
         lastQueryLabel.maxQueryDepth = maxQueryDepth
+        lastQueryLabel.minDegree = minDegree

         const state = useGraphStore.getState()
         state.reset()
-        fetchGraph(queryLabel, maxQueryDepth).then((data) => {
+        fetchGraph(queryLabel, maxQueryDepth, minDegree).then((data) => {
           // console.debug('Query label: ' + queryLabel)
           state.setSigmaGraph(createSigmaGraph(data))
           data?.buildDynamicMap()
@@ -203,7 +207,7 @@ const useLightrangeGraph = () => {
       state.reset()
       state.setSigmaGraph(new DirectedGraph())
     }
-  }, [queryLabel, maxQueryDepth])
+  }, [queryLabel, maxQueryDepth, minDegree])

   const lightrageGraph = useCallback(() => {
     if (sigmaGraph) {

View File

@@ -22,6 +22,9 @@ interface SettingsState {
   graphQueryMaxDepth: number
   setGraphQueryMaxDepth: (depth: number) => void

+  graphMinDegree: number
+  setGraphMinDegree: (degree: number) => void
+
   graphLayoutMaxIterations: number
   setGraphLayoutMaxIterations: (iterations: number) => void
@@ -66,6 +69,7 @@ const useSettingsStoreBase = create<SettingsState>()(
       enableEdgeEvents: false,
       graphQueryMaxDepth: 3,
+      graphMinDegree: 0,
       graphLayoutMaxIterations: 10,
       queryLabel: defaultQueryLabel,
@@ -107,6 +111,8 @@ const useSettingsStoreBase = create<SettingsState>()(
       setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
+      setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
+
       setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
       setApiKey: (apiKey: string | null) => set({ apiKey }),

View File

@@ -1 +1,11 @@
 /// <reference types="vite/client" />
+
+interface ImportMetaEnv {
+  readonly VITE_API_PROXY: string
+  readonly VITE_API_ENDPOINTS: string
+  readonly VITE_BACKEND_URL: string
+}
+
+interface ImportMeta {
+  readonly env: ImportMetaEnv
+}

View File

@@ -26,5 +26,5 @@
       "@/*": ["./src/*"]
     }
   },
-  "include": ["src"]
+  "include": ["src", "vite.config.ts"]
 }

View File

@@ -14,6 +14,21 @@ export default defineConfig({
   },
   base: './',
   build: {
-    outDir: path.resolve(__dirname, '../lightrag/api/webui')
+    outDir: path.resolve(__dirname, '../lightrag/api/webui'),
+    emptyOutDir: true
+  },
+  server: {
+    proxy: import.meta.env.VITE_API_PROXY === 'true' && import.meta.env.VITE_API_ENDPOINTS ?
+      Object.fromEntries(
+        import.meta.env.VITE_API_ENDPOINTS.split(',').map(endpoint => [
+          endpoint,
+          {
+            target: import.meta.env.VITE_BACKEND_URL || 'http://localhost:9621',
+            changeOrigin: true,
+            rewrite: endpoint === '/api' ?
+              (path) => path.replace(/^\/api/, '') : undefined
+          }
+        ])
+      ) : {}
   }
 })
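
The new `server.proxy` block is driven entirely by the three environment variables declared in `vite-env.d.ts`. A hypothetical `.env.local` for local development might look like the following; the endpoint list is an assumption for illustration and should match whichever API paths your frontend actually calls.

```bash
VITE_API_PROXY=true
VITE_API_ENDPOINTS=/api,/graphs,/health
VITE_BACKEND_URL=http://localhost:9621
```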