Merge branch 'main' into main

This commit is contained in:
zrguo
2025-03-09 00:23:06 +08:00
committed by GitHub
36 changed files with 884 additions and 191 deletions

View File

@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
### Redis
REDIS_URI=redis://localhost:6379
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/login,/health # white list

examples/test_postgres.py (new file, 51 lines)
View File

@@ -0,0 +1,51 @@
import os
import asyncio
from lightrag.kg.postgres_impl import PGGraphStorage
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./local_neo4jWorkDir"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# AGE
os.environ["AGE_GRAPH_NAME"] = "dickens"
os.environ["POSTGRES_HOST"] = "localhost"
os.environ["POSTGRES_PORT"] = "15432"
os.environ["POSTGRES_USER"] = "rag"
os.environ["POSTGRES_PASSWORD"] = "rag"
os.environ["POSTGRES_DATABASE"] = "rag"
async def main():
    graph_db = PGGraphStorage(
        namespace="dickens",
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="bge-m3", host="http://localhost:11434"
            ),
        ),
        global_config={},
    )

    await graph_db.initialize()
    labels = await graph_db.get_all_labels()
    print("all labels", labels)

    res = await graph_db.get_knowledge_graph("FEZZIWIG")
    print("knowledge graphs", res)

    await graph_db.finalize()


if __name__ == "__main__":
    asyncio.run(main())
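Note: this example assumes a PostgreSQL instance with the Apache AGE extension listening on localhost:15432 (user/password/database all "rag") and an Ollama server on localhost:11434 serving the bge-m3 embedding model; adjust the environment variables above to match your setup.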

View File

@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
pip install lightrag-hku
```
## Authentication Endpoints
### JWT Authentication Mechanism
The LightRAG API Server implements JWT-based authentication using the HS256 algorithm. To enable secure access control, the following environment variables are required:
```bash
# For jwt auth
AUTH_USERNAME=admin # login name
AUTH_PASSWORD=admin123 # password
TOKEN_SECRET=your-key # JWT key
TOKEN_EXPIRE_HOURS=4 # expire duration
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
```
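For reference, a minimal client-side sketch of the login flow described above. It assumes the server runs at the default `http://localhost:9621` and uses the example credentials from the `.env` snippet; the `/graphs` call is just one protected endpoint to demonstrate the Bearer token.

```python
import requests

BASE = "http://localhost:9621"  # assumed default API server address

# 1. Obtain a JWT via /login (form-encoded, OAuth2 password flow)
resp = requests.post(
    f"{BASE}/login",
    data={"username": "admin", "password": "admin123"},
)
resp.raise_for_status()
token = resp.json()["access_token"]

# 2. Call a protected endpoint with the Bearer token
headers = {"Authorization": f"Bearer {token}"}
graph = requests.get(f"{BASE}/graphs", params={"label": "*"}, headers=headers)
print(graph.status_code)
```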
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:

lightrag/api/auth.py (new file, 41 lines)
View File

@@ -0,0 +1,41 @@
import os
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
class TokenPayload(BaseModel):
    sub: str
    exp: datetime


class AuthHandler:
    def __init__(self):
        self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
        self.algorithm = "HS256"
        self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))

    def create_token(self, username: str) -> str:
        expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
        payload = TokenPayload(sub=username, exp=expire)
        return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> str:
        try:
            payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
            expire_timestamp = payload["exp"]
            expire_time = datetime.utcfromtimestamp(expire_timestamp)

            if datetime.utcnow() > expire_time:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
                )
            return payload["sub"]
        except jwt.PyJWTError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            )


auth_handler = AuthHandler()
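A quick sanity-check sketch of the handler above (hypothetical usage, not part of the diff; assumes the module is importable as `lightrag.api.auth`, which matches the `from .auth import auth_handler` import later in this commit):

```python
from lightrag.api.auth import auth_handler

# Issue a token for a user and validate it again; validate_token raises
# HTTPException(401) for expired or malformed tokens.
token = auth_handler.create_token("admin")
assert auth_handler.validate_token(token) == "admin"
```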

View File

@@ -2,10 +2,7 @@
LightRAG FastAPI Server
"""

-from fastapi import (
-    FastAPI,
-    Depends,
-)
+from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
    initialize_pipeline_status,
    get_all_update_flags_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from .auth import auth_handler

# Load environment variables
# Updated to use the .env that is inside the current folder
@@ -372,6 +371,27 @@ def create_app(args):
    ollama_api = OllamaAPI(rag, top_k=args.top_k)
    app.include_router(ollama_api.router, prefix="/api")

    @app.post("/login")
    async def login(form_data: OAuth2PasswordRequestForm = Depends()):
        username = os.getenv("AUTH_USERNAME")
        password = os.getenv("AUTH_PASSWORD")
        if not (username and password):
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="Authentication not configured",
            )
        if form_data.username != username or form_data.password != password:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
            )

        return {
            "access_token": auth_handler.create_token(username),
            "token_type": "bearer",
        }

    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status"""

View File

@@ -1,10 +1,20 @@
aiofiles
ascii_colors
asyncpg
distro
fastapi
httpcore
httpx
jiter
numpy
openai
passlib[bcrypt]
pipmaster
PyJWT
python-dotenv
python-jose[cryptography]
python-multipart
pytz
tenacity
tiktoken
uvicorn

View File

@@ -18,8 +18,11 @@ from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import get_api_key_dependency, global_args

-router = APIRouter(prefix="/documents", tags=["documents"])
+router = APIRouter(
+    prefix="/documents",
+    tags=["documents"],
+    dependencies=[Depends(get_auth_dependency())],
+)

# Temporary file prefix
temp_prefix = "__tmp__"

View File

@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
""" """
from typing import Optional
from fastapi import APIRouter, Depends

-from ..utils_api import get_api_key_dependency
+from ..utils_api import get_api_key_dependency, get_auth_dependency

-router = APIRouter(tags=["graph"])
+router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])


def create_graph_routes(rag, api_key: Optional[str] = None):
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
        return await rag.get_graph_labels()

    @router.get("/graphs", dependencies=[Depends(optional_api_key)])
-    async def get_knowledge_graph(label: str, max_depth: int = 3):
+    async def get_knowledge_graph(
+        label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
+    ):
        """
        Retrieve a connected subgraph of nodes where the label includes the specified label.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
-        1. Label matching nodes take precedence
-        2. Followed by nodes directly connected to the matching nodes
-        3. Finally, the degree of the nodes
+        1. min_degree does not affect nodes directly connected to the matching nodes
+        2. Label matching nodes take precedence
+        3. Followed by nodes directly connected to the matching nodes
+        4. Finally, the degree of the nodes
        Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)

        Args:
            label (str): Label to get knowledge graph for
            max_depth (int, optional): Maximum depth of graph. Defaults to 3.
+            inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
+            min_degree (int, optional): Minimum degree of nodes. Defaults to 0.

        Returns:
            Dict[str, List[str]]: Knowledge graph for label
        """
-        return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
+        return await rag.get_knowledge_graph(
+            node_label=label,
+            max_depth=max_depth,
+            inclusive=inclusive,
+            min_degree=min_degree,
+        )

    return router
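A hedged example of calling the extended endpoint with the new query parameters. Parameter names come from the route above; the host, label value, and the `LIGHTRAG_TOKEN` environment variable used to hold the JWT are illustrative assumptions, and the response field names follow the `KnowledgeGraph` model.

```python
import os
import requests

token = os.environ["LIGHTRAG_TOKEN"]  # hypothetical: the access_token returned by /login
resp = requests.get(
    "http://localhost:9621/graphs",
    params={"label": "Scrooge", "max_depth": 3, "min_degree": 2, "inclusive": "true"},
    headers={"Authorization": f"Bearer {token}"},
)
print(len(resp.json()["nodes"]), "nodes returned")
```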

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
-from ..utils_api import get_api_key_dependency
+from ..utils_api import get_api_key_dependency, get_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception

-router = APIRouter(tags=["query"])
+router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])


class QueryRequest(BaseModel):

View File

@@ -9,10 +9,11 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
-from fastapi import HTTPException, Security
+from fastapi import HTTPException, Security, Depends, Request
from dotenv import load_dotenv
-from fastapi.security import APIKeyHeader
+from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler

# Load environment variables
load_dotenv(override=True)
@@ -33,6 +34,24 @@ class OllamaServerInfos:
ollama_server_infos = OllamaServerInfos()


def get_auth_dependency():
    whitelist = os.getenv("WHITELIST_PATHS", "").split(",")

    async def dependency(
        request: Request,
        token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
    ):
        if request.url.path in whitelist:
            return

        if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
            return

        auth_handler.validate_token(token)

    return dependency


def get_api_key_dependency(api_key: Optional[str]):
    """
    Create an API key dependency for route protection.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./logo.png" /> <link rel="icon" type="image/svg+xml" href="./logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title> <title>Lightrag</title>
<script type="module" crossorigin src="./assets/index-DbuMPJAD.js"></script> <script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-rP-YlyR1.css"> <link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
</head> </head>
<body> <body>
<div id="root"></div> <div id="root"></div>

View File

@@ -204,7 +204,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
    @abstractmethod
    async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self, node_label: str, max_depth: int = 3
    ) -> KnowledgeGraph:
        """Retrieve a subgraph of the knowledge graph starting from a given node."""

View File

@@ -229,3 +229,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Get all records from the collection
# Since ChromaDB doesn't directly support prefix search on IDs,
# we'll get all records and filter in Python
results = self._collection.get(
include=["metadatas", "documents", "embeddings"]
)
matching_records = []
# Filter records where ID starts with the prefix
for i, record_id in enumerate(results["ids"]):
if record_id.startswith(prefix):
matching_records.append(
{
"id": record_id,
"content": results["documents"][i],
"vector": results["embeddings"][i],
**results["metadatas"][i],
}
)
logger.debug(
f"Found {len(matching_records)} records with prefix '{prefix}'"
)
return matching_records
except Exception as e:
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
raise
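The same `search_by_prefix` contract is added to each vector storage backend below. A minimal usage sketch, with hypothetical names: `vdb` stands for any already-initialized vector storage instance, and the `"ent-"`/`"rel-"` prefixes are the ones LightRAG uses for entity and relationship records (see the `compute_mdhash_id` calls later in this commit).

```python
async def show_entity_ids(vdb):
    # Every backend returns a list of dicts that include an "id" field.
    records = await vdb.search_by_prefix("ent-")
    for rec in records[:5]:
        print(rec["id"])
```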

View File

@@ -371,3 +371,24 @@ class FaissVectorDBStorage(BaseVectorStorage):
return False # Return error
return True # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
matching_records = []
# Search for records with IDs starting with the prefix
for faiss_id, meta in self._id_to_meta.items():
if "__id__" in meta and meta["__id__"].startswith(prefix):
# Create a copy of all metadata and add "id" field
record = {**meta, "id": meta["__id__"]}
matching_records.append(record)
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

View File

@@ -206,3 +206,28 @@ class MilvusVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use Milvus query with expression to find IDs with the given prefix
expression = f'id like "{prefix}%"'
results = self._client.query(
collection_name=self.namespace,
filter=expression,
output_fields=list(self.meta_fields) + ["id"],
)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
return results
except Exception as e:
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
return []

View File

@@ -1045,6 +1045,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
except PyMongoError as e:
logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use MongoDB regex to find documents where _id starts with the prefix
cursor = self._data.find({"_id": {"$regex": f"^{prefix}"}})
matching_records = await cursor.to_list(length=None)
# Format results
results = [{**doc, "id": doc["_id"]} for doc in matching_records]
logger.debug(
f"Found {len(results)} records with prefix '{prefix}' in {self.namespace}"
)
return results
except PyMongoError as e:
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
return []
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()

View File

@@ -236,3 +236,23 @@ class NanoVectorDBStorage(BaseVectorStorage):
return False # Return error
return True # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
storage = await self.client_storage
matching_records = []
# Search for records with IDs starting with the prefix
for record in storage["data"]:
if "__id__" in record and record["__id__"].startswith(prefix):
matching_records.append({**record, "id": record["__id__"]})
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

View File

@@ -232,19 +232,26 @@ class NetworkXStorage(BaseGraphStorage):
        return sorted(list(labels))

    async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
-        1. Label matching nodes take precedence
-        2. Followed by nodes directly connected to the matching nodes
-        3. Finally, the degree of the nodes
+        1. min_degree does not affect nodes directly connected to the matching nodes
+        2. Label matching nodes take precedence
+        3. Followed by nodes directly connected to the matching nodes
+        4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
+            min_degree: Minimum degree of nodes to include. Defaults to 0
+            inclusive: Do an inclusive search if true

        Returns:
            KnowledgeGraph object containing nodes and edges
@@ -255,6 +262,10 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph() graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# For "*", return the entire graph including all nodes and edges # For "*", return the entire graph including all nodes and edges
@@ -262,10 +273,15 @@ class NetworkXStorage(BaseGraphStorage):
graph.copy() graph.copy()
) # Create a copy to avoid modifying the original graph ) # Create a copy to avoid modifying the original graph
else: else:
# Find nodes with matching node id (partial match) # Find nodes with matching node id based on search_mode
nodes_to_explore = [] nodes_to_explore = []
for n, attr in graph.nodes(data=True): for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n) nodes_to_explore.append(n)
if not nodes_to_explore: if not nodes_to_explore:
@@ -277,26 +293,37 @@ class NetworkXStorage(BaseGraphStorage):
for start_node in nodes_to_explore: for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph) combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
subgraph = combined_subgraph subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES: if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes()) origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree()) node_degrees = dict(subgraph.degree())
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item): def priority_key(node_item):
node, degree = node_item node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0) # Priority order: start(2) > directly connected(1) > other nodes(0)
@@ -356,7 +383,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
    KnowledgeGraphEdge(
        id=edge_id,
-        type="RELATED",
+        type="DIRECTED",
        source=str(source),
        target=str(target),
        properties=edge_data,

View File

@@ -494,6 +494,41 @@ class OracleVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for entity {entity_name}: {e}") logger.error(f"Error deleting relations for entity {entity_name}: {e}")
raise raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Determine the appropriate table based on namespace
table_name = namespace_to_table_name(self.namespace)
# Create SQL query to find records with IDs starting with prefix
search_sql = f"""
SELECT * FROM {table_name}
WHERE workspace = :workspace
AND id LIKE :prefix_pattern
ORDER BY id
"""
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
# Execute query and get results
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results or []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@dataclass

View File

@@ -575,6 +575,41 @@ class PGVectorStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for prefix search: {self.namespace}")
return []
search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}
try:
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
# Format results to match the expected return format
formatted_results = []
for record in results:
formatted_record = dict(record)
# Ensure id field is available (for consistency with NanoVectorDB implementation)
if "id" not in formatted_record:
formatted_record["id"] = record["id"]
formatted_results.append(formatted_record)
return formatted_results
except Exception as e:
logger.error(f"Error during prefix search for '{prefix}': {e}")
return []
@final
@dataclass
@@ -775,6 +810,14 @@ class PGGraphStorage(BaseGraphStorage):
v = record[k] v = record[k]
# agtype comes back '{key: value}::type' which must be parsed # agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v: if isinstance(v, str) and "::" in v:
if v.startswith("[") and v.endswith("]"):
if "::vertex" not in v:
continue
v = v.replace("::vertex", "")
vertexes = json.loads(v)
for vertex in vertexes:
vertices[vertex["id"]] = vertex.get("properties")
else:
dtype = v.split("::")[-1] dtype = v.split("::")[-1]
v = v.split("::")[0] v = v.split("::")[0]
if dtype == "vertex": if dtype == "vertex":
@@ -785,17 +828,49 @@ class PGGraphStorage(BaseGraphStorage):
for k in record.keys(): for k in record.keys():
v = record[k] v = record[k]
if isinstance(v, str) and "::" in v: if isinstance(v, str) and "::" in v:
if v.startswith("[") and v.endswith("]"):
if "::vertex" in v:
v = v.replace("::vertex", "")
vertexes = json.loads(v)
dl = []
for vertex in vertexes:
prop = vertex.get("properties")
if not prop:
prop = {}
prop["label"] = PGGraphStorage._decode_graph_label(
prop["node_id"]
)
dl.append(prop)
d[k] = dl
elif "::edge" in v:
v = v.replace("::edge", "")
edges = json.loads(v)
dl = []
for edge in edges:
dl.append(
(
vertices[edge["start_id"]],
edge["label"],
vertices[edge["end_id"]],
)
)
d[k] = dl
else:
print("WARNING: unsupported type")
continue
else:
dtype = v.split("::")[-1] dtype = v.split("::")[-1]
v = v.split("::")[0] v = v.split("::")[0]
else:
dtype = ""
if dtype == "vertex": if dtype == "vertex":
vertex = json.loads(v) vertex = json.loads(v)
field = vertex.get("properties") field = vertex.get("properties")
if not field: if not field:
field = {} field = {}
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) field["label"] = PGGraphStorage._decode_graph_label(
field["node_id"]
)
d[k] = field d[k] = field
# convert edge from id-label->id by replacing id with node information # convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query # we only do this if the vertex was also returned in the query
@@ -809,6 +884,9 @@ class PGGraphStorage(BaseGraphStorage):
], # we don't use decode_graph_label(), since edge label is always "DIRECTED" ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}), vertices.get(edge["end_id"], {}),
) )
else:
if v is None or (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
else: else:
d[k] = json.loads(v) if isinstance(v, str) else v d[k] = json.loads(v) if isinstance(v, str) else v
@@ -1284,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m) OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % ( $$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name, self.graph_name,
encoded_node_label, encoded_node_label,
max_depth, max_depth,
@@ -1293,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
results = await self._query(query) results = await self._query(query)
nodes = set() nodes = {}
edges = [] edges = []
unique_edge_ids = set()
for result in results: for result in results:
if node_label == "*": if node_label == "*":
if result["n"]: if result["n"]:
node = result["n"] node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["m"]: if result["m"]:
node = result["m"] node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]: if result["r"]:
edge = result["r"] edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"]) src_id = self._decode_graph_label(edge["start_id"])
@@ -1312,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
else: else:
if result["nodes"]: if result["nodes"]:
for node in result["nodes"]: for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"])) node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["relationships"]: if result["relationships"]:
for edge in result["relationships"]: for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge["start_id"]) src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge["end_id"]) tgt_id = self._decode_graph_label(edge[2]["node_id"])
edges.append((src_id, tgt_id)) id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
kg = KnowledgeGraph( kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes], nodes=[
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges], KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
],
) )
return kg return kg

View File

@@ -233,3 +233,45 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"No relations found for entity {entity_name}") logger.debug(f"No relations found for entity {entity_name}")
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use scroll method to find records with IDs starting with the prefix
results = self._client.scroll(
collection_name=self.namespace,
scroll_filter=models.Filter(
must=[
models.FieldCondition(
key="id", match=models.MatchText(text=prefix, prefix=True)
)
]
),
with_payload=True,
with_vectors=False,
limit=1000, # Adjust as needed for your use case
)
# Extract matching points
matching_records = results[0]
# Format the results to match expected return format
formatted_results = [
{**point.payload, "id": point.id} for point in matching_records
]
logger.debug(
f"Found {len(formatted_results)} records with prefix '{prefix}'"
)
return formatted_results
except Exception as e:
logger.error(f"Error searching for prefix '{prefix}': {e}")
return []

View File

@@ -414,6 +414,55 @@ class TiDBVectorDBStorage(BaseVectorStorage):
# Ti handles persistence automatically
pass
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
"""
else:
logger.warning(
f"Namespace {self.namespace} not supported for prefix search"
)
return []
# Add prefix pattern parameter with % for SQL LIKE
prefix_pattern = f"{prefix}%"
params = {"prefix_pattern": prefix_pattern, "workspace": self.db.workspace}
try:
results = await self.db.query(sql_template, params=params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results if results else []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@dataclass
@@ -968,4 +1017,20 @@ SQL_TEMPLATES = {
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
""",
# Search by prefix SQL templates
"search_entity_by_prefix": """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_relationship_by_prefix": """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id, keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_chunk_by_prefix": """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
""",
}

View File

@@ -504,11 +504,39 @@ class LightRAG:
        return text

    async def get_knowledge_graph(
-        self, node_label: str, max_depth: int
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
    ) -> KnowledgeGraph:
-        return await self.chunk_entity_relation_graph.get_knowledge_graph(
-            node_label=node_label, max_depth=max_depth
-        )
+        """Get knowledge graph for a given label
+
+        Args:
+            node_label (str): Label to get knowledge graph for
+            max_depth (int): Maximum depth of graph
+            min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
+            inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
+
+        Returns:
+            KnowledgeGraph: Knowledge graph containing nodes and edges
+        """
+        # get params supported by get_knowledge_graph of specified storage
+        import inspect
+
+        storage_params = inspect.signature(
+            self.chunk_entity_relation_graph.get_knowledge_graph
+        ).parameters
+
+        kwargs = {"node_label": node_label, "max_depth": max_depth}
+
+        if "min_degree" in storage_params and min_degree > 0:
+            kwargs["min_degree"] = min_degree
+
+        if "inclusive" in storage_params:
+            kwargs["inclusive"] = inclusive
+
+        return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)

    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
        import_path = STORAGES[storage_name]
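A short usage sketch of the updated method. The label and thresholds are illustrative; it assumes an already-initialized `LightRAG` instance named `rag`.

```python
async def fetch_subgraph(rag):
    # Inclusive match on the label, keeping nodes with degree >= 2
    # (plus matching nodes and their direct neighbours, per the docstring).
    kg = await rag.get_knowledge_graph(
        node_label="Scrooge", max_depth=3, min_degree=2, inclusive=True
    )
    print(len(kg.nodes), "nodes,", len(kg.edges), "edges")
```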
@@ -2016,6 +2044,9 @@ class LightRAG:
# Delete old entity record from vector database
old_entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([old_entity_id])
logger.info(
f"Deleted old entity '{entity_name}' and its vector embedding from database"
)
# Update relationship vector representations
for src, tgt, edge_data in relations_to_update:
@@ -2143,6 +2174,15 @@ class LightRAG:
f"Relation from '{source_entity}' to '{target_entity}' does not exist" f"Relation from '{source_entity}' to '{target_entity}' does not exist"
) )
# Important: First delete the old relation record from the vector database
old_relation_id = compute_mdhash_id(
source_entity + target_entity, prefix="rel-"
)
await self.relationships_vdb.delete([old_relation_id])
logger.info(
f"Deleted old relation record from vector database for relation {source_entity} -> {target_entity}"
)
# 2. Update relation information in the graph
new_edge_data = {**edge_data, **updated_data}
await self.chunk_entity_relation_graph.upsert_edge(
@@ -2641,12 +2681,29 @@ class LightRAG:
# 9. Delete source entities
for entity_name in source_entities:
-# Delete entity node
+# Delete entity node from knowledge graph
await self.chunk_entity_relation_graph.delete_node(entity_name)
-# Delete record from vector database
+# Delete entity record from vector database
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([entity_id])
-logger.info(f"Deleted source entity '{entity_name}'")
# Also ensure any relationships specific to this entity are deleted from vector DB
# This is a safety check, as these should have been transformed to the target entity already
entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
relations_with_entity = await self.relationships_vdb.search_by_prefix(
entity_relation_prefix
)
if relations_with_entity:
relation_ids = [r["id"] for r in relations_with_entity]
await self.relationships_vdb.delete(relation_ids)
logger.info(
f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
)
logger.info(
f"Deleted source entity '{entity_name}' and its vector embedding from database"
)
# 10. Save changes
await self._merge_entities_done()

View File

@@ -1152,7 +1152,8 @@ async def _get_node_data(
"entity", "entity",
"type", "type",
"description", "description",
"rank" "created_at", "rank",
"created_at",
] ]
] ]
for i, n in enumerate(node_datas): for i, n in enumerate(node_datas):

View File

@@ -161,8 +161,12 @@ axiosInstance.interceptors.response.use(
)

// API methods
-export const queryGraphs = async (label: string, maxDepth: number): Promise<LightragGraphType> => {
-  const response = await axiosInstance.get(`/graphs?label=${label}&max_depth=${maxDepth}`)
+export const queryGraphs = async (
+  label: string,
+  maxDepth: number,
+  minDegree: number
+): Promise<LightragGraphType> => {
+  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
  return response.data
}

View File

@@ -40,18 +40,21 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
const focusedEdge = useGraphStore.use.focusedEdge() const focusedEdge = useGraphStore.use.focusedEdge()
/** /**
* When component mount * When component mount or maxIterations changes
* => load the graph * => load the graph and apply layout
*/ */
useEffect(() => { useEffect(() => {
// Create & load the graph // Create & load the graph
const graph = lightrageGraph() const graph = lightrageGraph()
loadGraph(graph) loadGraph(graph)
if (!(graph as any).__force_applied) {
assignLayout() assignLayout()
Object.assign(graph, { __force_applied: true }) }, [assignLayout, loadGraph, lightrageGraph, maxIterations])
}
/**
* When component mount
* => register events
*/
useEffect(() => {
const { setFocusedNode, setSelectedNode, setFocusedEdge, setSelectedEdge, clearSelection } = const { setFocusedNode, setSelectedNode, setFocusedEdge, setSelectedEdge, clearSelection } =
useGraphStore.getState() useGraphStore.getState()
@@ -87,7 +90,7 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
}, },
clickStage: () => clearSelection() clickStage: () => clearSelection()
}) })
}, [assignLayout, loadGraph, registerEvents, lightrageGraph]) }, [registerEvents])
/** /**
* When component mount or hovered node change * When component mount or hovered node change

View File

@@ -90,9 +90,12 @@ const LabeledNumberInput = ({
{label} {label}
</label> </label>
<Input <Input
value={currentValue || ''} type="number"
value={currentValue === null ? '' : currentValue}
onChange={onValueChange} onChange={onValueChange}
className="h-6 w-full min-w-0" className="h-6 w-full min-w-0 pr-1"
min={min}
max={max}
onBlur={onBlur} onBlur={onBlur}
onKeyDown={(e) => { onKeyDown={(e) => {
if (e.key === 'Enter') { if (e.key === 'Enter') {
@@ -119,6 +122,7 @@ export default function Settings() {
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges() const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
const showEdgeLabel = useSettingsStore.use.showEdgeLabel() const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth() const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
const graphMinDegree = useSettingsStore.use.graphMinDegree()
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations() const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
const enableHealthCheck = useSettingsStore.use.enableHealthCheck() const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
@@ -177,6 +181,11 @@ export default function Settings() {
useSettingsStore.setState({ graphQueryMaxDepth: depth }) useSettingsStore.setState({ graphQueryMaxDepth: depth })
}, []) }, [])
const setGraphMinDegree = useCallback((degree: number) => {
if (degree < 0) return
useSettingsStore.setState({ graphMinDegree: degree })
}, [])
const setGraphLayoutMaxIterations = useCallback((iterations: number) => { const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
if (iterations < 1) return if (iterations < 1) return
useSettingsStore.setState({ graphLayoutMaxIterations: iterations }) useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
@@ -266,6 +275,12 @@ export default function Settings() {
value={graphQueryMaxDepth} value={graphQueryMaxDepth}
onEditFinished={setGraphQueryMaxDepth} onEditFinished={setGraphQueryMaxDepth}
/> />
<LabeledNumberInput
label="Minimum Degree"
min={0}
value={graphMinDegree}
onEditFinished={setGraphMinDegree}
/>
<LabeledNumberInput <LabeledNumberInput
label="Max Layout Iterations" label="Max Layout Iterations"
min={1} min={1}

View File

@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
<input <input
type={type} type={type}
className={cn( className={cn(
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm', 'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
className className
)} )}
ref={ref} ref={ref}

View File

@@ -50,11 +50,11 @@ export type NodeType = {
} }
export type EdgeType = { label: string } export type EdgeType = { label: string }
const fetchGraph = async (label: string, maxDepth: number) => { const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
let rawData: any = null let rawData: any = null
try { try {
rawData = await queryGraphs(label, maxDepth) rawData = await queryGraphs(label, maxDepth, minDegree)
} catch (e) { } catch (e) {
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!') useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
return null return null
@@ -161,13 +161,14 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
  return graph
}

-const lastQueryLabel = { label: '', maxQueryDepth: 0 }
+const lastQueryLabel = { label: '', maxQueryDepth: 0, minDegree: 0 }
const useLightrangeGraph = () => {
  const queryLabel = useSettingsStore.use.queryLabel()
  const rawGraph = useGraphStore.use.rawGraph()
  const sigmaGraph = useGraphStore.use.sigmaGraph()
  const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
const minDegree = useSettingsStore.use.graphMinDegree()
const getNode = useCallback( const getNode = useCallback(
(nodeId: string) => { (nodeId: string) => {
@@ -185,13 +186,16 @@ const useLightrangeGraph = () => {
useEffect(() => { useEffect(() => {
if (queryLabel) { if (queryLabel) {
if (lastQueryLabel.label !== queryLabel || lastQueryLabel.maxQueryDepth !== maxQueryDepth) { if (lastQueryLabel.label !== queryLabel ||
lastQueryLabel.maxQueryDepth !== maxQueryDepth ||
lastQueryLabel.minDegree !== minDegree) {
lastQueryLabel.label = queryLabel lastQueryLabel.label = queryLabel
lastQueryLabel.maxQueryDepth = maxQueryDepth lastQueryLabel.maxQueryDepth = maxQueryDepth
lastQueryLabel.minDegree = minDegree
const state = useGraphStore.getState() const state = useGraphStore.getState()
state.reset() state.reset()
fetchGraph(queryLabel, maxQueryDepth).then((data) => { fetchGraph(queryLabel, maxQueryDepth, minDegree).then((data) => {
// console.debug('Query label: ' + queryLabel) // console.debug('Query label: ' + queryLabel)
state.setSigmaGraph(createSigmaGraph(data)) state.setSigmaGraph(createSigmaGraph(data))
data?.buildDynamicMap() data?.buildDynamicMap()
@@ -203,7 +207,7 @@ const useLightrangeGraph = () => {
state.reset() state.reset()
state.setSigmaGraph(new DirectedGraph()) state.setSigmaGraph(new DirectedGraph())
} }
}, [queryLabel, maxQueryDepth]) }, [queryLabel, maxQueryDepth, minDegree])
const lightrageGraph = useCallback(() => { const lightrageGraph = useCallback(() => {
if (sigmaGraph) { if (sigmaGraph) {

View File

@@ -22,6 +22,9 @@ interface SettingsState {
  graphQueryMaxDepth: number
  setGraphQueryMaxDepth: (depth: number) => void

  graphMinDegree: number
  setGraphMinDegree: (degree: number) => void

  graphLayoutMaxIterations: number
  setGraphLayoutMaxIterations: (iterations: number) => void
@@ -66,6 +69,7 @@ const useSettingsStoreBase = create<SettingsState>()(
      enableEdgeEvents: false,
      graphQueryMaxDepth: 3,
      graphMinDegree: 0,
      graphLayoutMaxIterations: 10,
      queryLabel: defaultQueryLabel,
@@ -107,6 +111,8 @@ const useSettingsStoreBase = create<SettingsState>()(
      setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
      setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
      setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
      setApiKey: (apiKey: string | null) => set({ apiKey }),

View File

@@ -1 +1,11 @@
/// <reference types="vite/client" />

interface ImportMetaEnv {
  readonly VITE_API_PROXY: string
  readonly VITE_API_ENDPOINTS: string
  readonly VITE_BACKEND_URL: string
}

interface ImportMeta {
  readonly env: ImportMetaEnv
}

View File

@@ -26,5 +26,5 @@
"@/*": ["./src/*"] "@/*": ["./src/*"]
} }
}, },
"include": ["src"] "include": ["src", "vite.config.ts"]
} }

View File

@@ -14,6 +14,21 @@ export default defineConfig({
  },
  base: './',
  build: {
    outDir: path.resolve(__dirname, '../lightrag/api/webui'),
    emptyOutDir: true
  },
  server: {
    proxy: import.meta.env.VITE_API_PROXY === 'true' && import.meta.env.VITE_API_ENDPOINTS ?
      Object.fromEntries(
        import.meta.env.VITE_API_ENDPOINTS.split(',').map(endpoint => [
          endpoint,
          {
            target: import.meta.env.VITE_BACKEND_URL || 'http://localhost:9621',
            changeOrigin: true,
            rewrite: endpoint === '/api' ?
              (path) => path.replace(/^\/api/, '') : undefined
          }
        ])
      ) : {}
  }
})
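An illustrative env file for the new dev-server proxy. The variable names come from `vite-env.d.ts` above; the endpoint list and backend URL are assumptions for a typical local setup and should be adjusted to your deployment.

```bash
VITE_API_PROXY=true
VITE_API_ENDPOINTS=/api,/documents,/graphs,/query,/login,/health
VITE_BACKEND_URL=http://localhost:9621
```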