Merge branch 'main' into main
This commit is contained in:
@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
|
||||
|
||||
### Redis
|
||||
REDIS_URI=redis://localhost:6379
|
||||
|
||||
# For jwt auth
|
||||
AUTH_USERNAME=admin # login name
|
||||
AUTH_PASSWORD=admin123 # password
|
||||
TOKEN_SECRET=your-key # JWT key
|
||||
TOKEN_EXPIRE_HOURS=4 # expire duration
|
||||
WHITELIST_PATHS=/login,/health # white list
|
||||
|
51
examples/test_postgres.py
Normal file
51
examples/test_postgres.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import asyncio
|
||||
from lightrag.kg.postgres_impl import PGGraphStorage
|
||||
from lightrag.llm.ollama import ollama_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
|
||||
#########
|
||||
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
#########
|
||||
|
||||
WORKING_DIR = "./local_neo4jWorkDir"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
# AGE
|
||||
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
||||
|
||||
os.environ["POSTGRES_HOST"] = "localhost"
|
||||
os.environ["POSTGRES_PORT"] = "15432"
|
||||
os.environ["POSTGRES_USER"] = "rag"
|
||||
os.environ["POSTGRES_PASSWORD"] = "rag"
|
||||
os.environ["POSTGRES_DATABASE"] = "rag"
|
||||
|
||||
|
||||
async def main():
|
||||
graph_db = PGGraphStorage(
|
||||
namespace="dickens",
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
func=lambda texts: ollama_embedding(
|
||||
texts, embed_model="bge-m3", host="http://localhost:11434"
|
||||
),
|
||||
),
|
||||
global_config={},
|
||||
)
|
||||
await graph_db.initialize()
|
||||
labels = await graph_db.get_all_labels()
|
||||
print("all labels", labels)
|
||||
|
||||
res = await graph_db.get_knowledge_graph("FEZZIWIG")
|
||||
print("knowledge graphs", res)
|
||||
|
||||
await graph_db.finalize()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
|
||||
pip install lightrag-hku
|
||||
```
|
||||
|
||||
## Authentication Endpoints
|
||||
|
||||
### JWT Authentication Mechanism
|
||||
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
|
||||
```bash
|
||||
# For jwt auth
|
||||
AUTH_USERNAME=admin # login name
|
||||
AUTH_PASSWORD=admin123 # password
|
||||
TOKEN_SECRET=your-key # JWT key
|
||||
TOKEN_EXPIRE_HOURS=4 # expire duration
|
||||
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
||||
|
41
lightrag/api/auth.py
Normal file
41
lightrag/api/auth.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str
|
||||
exp: datetime
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self):
|
||||
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
|
||||
self.algorithm = "HS256"
|
||||
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
|
||||
|
||||
def create_token(self, username: str) -> str:
|
||||
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
||||
payload = TokenPayload(sub=username, exp=expire)
|
||||
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
|
||||
|
||||
def validate_token(self, token: str) -> str:
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
expire_timestamp = payload["exp"]
|
||||
expire_time = datetime.utcfromtimestamp(expire_timestamp)
|
||||
|
||||
if datetime.utcnow() > expire_time:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
)
|
||||
return payload["sub"]
|
||||
except jwt.PyJWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
auth_handler = AuthHandler()
|
@@ -2,10 +2,7 @@
|
||||
LightRAG FastAPI Server
|
||||
"""
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Depends,
|
||||
)
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
|
||||
initialize_pipeline_status,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from .auth import auth_handler
|
||||
|
||||
# Load environment variables
|
||||
# Updated to use the .env that is inside the current folder
|
||||
@@ -372,6 +371,27 @@ def create_app(args):
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
||||
@app.post("/login")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username = os.getenv("AUTH_USERNAME")
|
||||
password = os.getenv("AUTH_PASSWORD")
|
||||
|
||||
if not (username and password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if form_data.username != username or form_data.password != password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": auth_handler.create_token(username),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
|
@@ -1,10 +1,20 @@
|
||||
aiofiles
|
||||
ascii_colors
|
||||
asyncpg
|
||||
distro
|
||||
fastapi
|
||||
httpcore
|
||||
httpx
|
||||
jiter
|
||||
numpy
|
||||
openai
|
||||
passlib[bcrypt]
|
||||
pipmaster
|
||||
PyJWT
|
||||
python-dotenv
|
||||
python-jose[cryptography]
|
||||
python-multipart
|
||||
pytz
|
||||
tenacity
|
||||
tiktoken
|
||||
uvicorn
|
||||
|
@@ -18,8 +18,11 @@ from lightrag import LightRAG
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.api.utils_api import get_api_key_dependency, global_args
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(get_auth_dependency())],
|
||||
)
|
||||
|
||||
# Temporary file prefix
|
||||
temp_prefix = "__tmp__"
|
||||
|
@@ -3,12 +3,11 @@ This module contains all graph-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ..utils_api import get_api_key_dependency
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"])
|
||||
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
@@ -25,23 +24,33 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
return await rag.get_graph_labels()
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
||||
async def get_knowledge_graph(label: str, max_depth: int = 3):
|
||||
async def get_knowledge_graph(
|
||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||
):
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Label matching nodes take precedence
|
||||
2. Followed by nodes directly connected to the matching nodes
|
||||
3. Finally, the degree of the nodes
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
|
||||
|
||||
Args:
|
||||
label (str): Label to get knowledge graph for
|
||||
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
|
||||
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
|
||||
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Knowledge graph for label
|
||||
"""
|
||||
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
|
||||
return await rag.get_knowledge_graph(
|
||||
node_label=label,
|
||||
max_depth=max_depth,
|
||||
inclusive=inclusive,
|
||||
min_degree=min_degree,
|
||||
)
|
||||
|
||||
return router
|
||||
|
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam
|
||||
from ..utils_api import get_api_key_dependency
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from ascii_colors import trace_exception
|
||||
|
||||
router = APIRouter(tags=["query"])
|
||||
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
|
@@ -9,10 +9,11 @@ import sys
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import HTTPException, Security, Depends, Request
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -33,6 +34,24 @@ class OllamaServerInfos:
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
def get_auth_dependency():
|
||||
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
|
||||
|
||||
async def dependency(
|
||||
request: Request,
|
||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
||||
):
|
||||
if request.url.path in whitelist:
|
||||
return
|
||||
|
||||
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
|
||||
return
|
||||
|
||||
auth_handler.validate_token(token)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
"""
|
||||
Create an API key dependency for route protection.
|
||||
|
1
lightrag/api/webui/assets/index-CH-3l4_Z.css
Normal file
1
lightrag/api/webui/assets/index-CH-3l4_Z.css
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="./logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lightrag</title>
|
||||
<script type="module" crossorigin src="./assets/index-DbuMPJAD.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-rP-YlyR1.css">
|
||||
<script type="module" crossorigin src="./assets/index-CJz72b6Q.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-CH-3l4_Z.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
@@ -204,7 +204,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
self, node_label: str, max_depth: int = 3
|
||||
) -> KnowledgeGraph:
|
||||
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||
|
||||
|
@@ -229,3 +229,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
raise
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
try:
|
||||
# Get all records from the collection
|
||||
# Since ChromaDB doesn't directly support prefix search on IDs,
|
||||
# we'll get all records and filter in Python
|
||||
results = self._collection.get(
|
||||
include=["metadatas", "documents", "embeddings"]
|
||||
)
|
||||
|
||||
matching_records = []
|
||||
|
||||
# Filter records where ID starts with the prefix
|
||||
for i, record_id in enumerate(results["ids"]):
|
||||
if record_id.startswith(prefix):
|
||||
matching_records.append(
|
||||
{
|
||||
"id": record_id,
|
||||
"content": results["documents"][i],
|
||||
"vector": results["embeddings"][i],
|
||||
**results["metadatas"][i],
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Found {len(matching_records)} records with prefix '{prefix}'"
|
||||
)
|
||||
return matching_records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
|
||||
raise
|
||||
|
@@ -371,3 +371,24 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
matching_records = []
|
||||
|
||||
# Search for records with IDs starting with the prefix
|
||||
for faiss_id, meta in self._id_to_meta.items():
|
||||
if "__id__" in meta and meta["__id__"].startswith(prefix):
|
||||
# Create a copy of all metadata and add "id" field
|
||||
record = {**meta, "id": meta["__id__"]}
|
||||
matching_records.append(record)
|
||||
|
||||
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
|
||||
return matching_records
|
||||
|
@@ -206,3 +206,28 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
try:
|
||||
# Use Milvus query with expression to find IDs with the given prefix
|
||||
expression = f'id like "{prefix}%"'
|
||||
results = self._client.query(
|
||||
collection_name=self.namespace,
|
||||
filter=expression,
|
||||
output_fields=list(self.meta_fields) + ["id"],
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
@@ -1045,6 +1045,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
try:
|
||||
# Use MongoDB regex to find documents where _id starts with the prefix
|
||||
cursor = self._data.find({"_id": {"$regex": f"^{prefix}"}})
|
||||
matching_records = await cursor.to_list(length=None)
|
||||
|
||||
# Format results
|
||||
results = [{**doc, "id": doc["_id"]} for doc in matching_records]
|
||||
|
||||
logger.debug(
|
||||
f"Found {len(results)} records with prefix '{prefix}' in {self.namespace}"
|
||||
)
|
||||
return results
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
@@ -236,3 +236,23 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
storage = await self.client_storage
|
||||
matching_records = []
|
||||
|
||||
# Search for records with IDs starting with the prefix
|
||||
for record in storage["data"]:
|
||||
if "__id__" in record and record["__id__"].startswith(prefix):
|
||||
matching_records.append({**record, "id": record["__id__"]})
|
||||
|
||||
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
|
||||
return matching_records
|
||||
|
@@ -232,19 +232,26 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
return sorted(list(labels))
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Label matching nodes take precedence
|
||||
2. Followed by nodes directly connected to the matching nodes
|
||||
3. Finally, the degree of the nodes
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node
|
||||
max_depth: Maximum depth of the subgraph
|
||||
min_degree: Minimum degree of nodes to include. Defaults to 0
|
||||
inclusive: Do an inclusive search if true
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges
|
||||
@@ -255,6 +262,10 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
# Initialize sets for start nodes and direct connected nodes
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
@@ -262,10 +273,15 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
graph.copy()
|
||||
) # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id (partial match)
|
||||
# Find nodes with matching node id based on search_mode
|
||||
nodes_to_explore = []
|
||||
for n, attr in graph.nodes(data=True):
|
||||
if node_label in str(n): # Use partial matching
|
||||
node_str = str(n)
|
||||
if not inclusive:
|
||||
if node_label == node_str: # Use exact matching
|
||||
nodes_to_explore.append(n)
|
||||
else: # inclusive mode
|
||||
if node_label in node_str: # Use partial matching
|
||||
nodes_to_explore.append(n)
|
||||
|
||||
if not nodes_to_explore:
|
||||
@@ -277,26 +293,37 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
for start_node in nodes_to_explore:
|
||||
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||
|
||||
# Get start nodes and direct connected nodes
|
||||
if nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
direct_connected_nodes.update(
|
||||
combined_subgraph.neighbors(start_node)
|
||||
)
|
||||
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
subgraph = combined_subgraph
|
||||
|
||||
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
||||
if min_degree > 0:
|
||||
nodes_to_keep = [
|
||||
node
|
||||
for node, degree in subgraph.degree()
|
||||
if node in start_nodes
|
||||
or node in direct_connected_nodes
|
||||
or degree >= min_degree
|
||||
]
|
||||
subgraph = subgraph.subgraph(nodes_to_keep)
|
||||
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||
origin_nodes = len(subgraph.nodes())
|
||||
|
||||
node_degrees = dict(subgraph.degree())
|
||||
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
if node_label != "*" and nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
direct_connected_nodes.update(subgraph.neighbors(start_node))
|
||||
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
def priority_key(node_item):
|
||||
node, degree = node_item
|
||||
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
||||
@@ -356,7 +383,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="RELATED",
|
||||
type="DIRECTED",
|
||||
source=str(source),
|
||||
target=str(target),
|
||||
properties=edge_data,
|
||||
|
@@ -494,6 +494,41 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
||||
raise
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
try:
|
||||
# Determine the appropriate table based on namespace
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
|
||||
# Create SQL query to find records with IDs starting with prefix
|
||||
search_sql = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE workspace = :workspace
|
||||
AND id LIKE :prefix_pattern
|
||||
ORDER BY id
|
||||
"""
|
||||
|
||||
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
|
||||
|
||||
# Execute query and get results
|
||||
results = await self.db.query(search_sql, params, multirows=True)
|
||||
|
||||
logger.debug(
|
||||
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
|
||||
)
|
||||
return results or []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
@@ -575,6 +575,41 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for prefix search: {self.namespace}")
|
||||
return []
|
||||
|
||||
search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
|
||||
params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}
|
||||
|
||||
try:
|
||||
results = await self.db.query(search_sql, params, multirows=True)
|
||||
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
|
||||
|
||||
# Format results to match the expected return format
|
||||
formatted_results = []
|
||||
for record in results:
|
||||
formatted_record = dict(record)
|
||||
# Ensure id field is available (for consistency with NanoVectorDB implementation)
|
||||
if "id" not in formatted_record:
|
||||
formatted_record["id"] = record["id"]
|
||||
formatted_results.append(formatted_record)
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"Error during prefix search for '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -775,6 +810,14 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
v = record[k]
|
||||
# agtype comes back '{key: value}::type' which must be parsed
|
||||
if isinstance(v, str) and "::" in v:
|
||||
if v.startswith("[") and v.endswith("]"):
|
||||
if "::vertex" not in v:
|
||||
continue
|
||||
v = v.replace("::vertex", "")
|
||||
vertexes = json.loads(v)
|
||||
for vertex in vertexes:
|
||||
vertices[vertex["id"]] = vertex.get("properties")
|
||||
else:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
if dtype == "vertex":
|
||||
@@ -785,17 +828,49 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
for k in record.keys():
|
||||
v = record[k]
|
||||
if isinstance(v, str) and "::" in v:
|
||||
if v.startswith("[") and v.endswith("]"):
|
||||
if "::vertex" in v:
|
||||
v = v.replace("::vertex", "")
|
||||
vertexes = json.loads(v)
|
||||
dl = []
|
||||
for vertex in vertexes:
|
||||
prop = vertex.get("properties")
|
||||
if not prop:
|
||||
prop = {}
|
||||
prop["label"] = PGGraphStorage._decode_graph_label(
|
||||
prop["node_id"]
|
||||
)
|
||||
dl.append(prop)
|
||||
d[k] = dl
|
||||
|
||||
elif "::edge" in v:
|
||||
v = v.replace("::edge", "")
|
||||
edges = json.loads(v)
|
||||
dl = []
|
||||
for edge in edges:
|
||||
dl.append(
|
||||
(
|
||||
vertices[edge["start_id"]],
|
||||
edge["label"],
|
||||
vertices[edge["end_id"]],
|
||||
)
|
||||
)
|
||||
d[k] = dl
|
||||
else:
|
||||
print("WARNING: unsupported type")
|
||||
continue
|
||||
|
||||
else:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
else:
|
||||
dtype = ""
|
||||
|
||||
if dtype == "vertex":
|
||||
vertex = json.loads(v)
|
||||
field = vertex.get("properties")
|
||||
if not field:
|
||||
field = {}
|
||||
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
|
||||
field["label"] = PGGraphStorage._decode_graph_label(
|
||||
field["node_id"]
|
||||
)
|
||||
d[k] = field
|
||||
# convert edge from id-label->id by replacing id with node information
|
||||
# we only do this if the vertex was also returned in the query
|
||||
@@ -809,6 +884,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
||||
vertices.get(edge["end_id"], {}),
|
||||
)
|
||||
else:
|
||||
if v is None or (v.count("{") < 1 and v.count("[") < 1):
|
||||
d[k] = v
|
||||
else:
|
||||
d[k] = json.loads(v) if isinstance(v, str) else v
|
||||
|
||||
@@ -1284,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
||||
LIMIT %d
|
||||
$$) AS (nodes agtype[], relationships agtype[])""" % (
|
||||
$$) AS (nodes agtype, relationships agtype)""" % (
|
||||
self.graph_name,
|
||||
encoded_node_label,
|
||||
max_depth,
|
||||
@@ -1293,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
nodes = set()
|
||||
nodes = {}
|
||||
edges = []
|
||||
unique_edge_ids = set()
|
||||
|
||||
for result in results:
|
||||
if node_label == "*":
|
||||
if result["n"]:
|
||||
node = result["n"]
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
|
||||
if result["m"]:
|
||||
node = result["m"]
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
if result["r"]:
|
||||
edge = result["r"]
|
||||
src_id = self._decode_graph_label(edge["start_id"])
|
||||
@@ -1312,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
if result["nodes"]:
|
||||
for node in result["nodes"]:
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
node_id = self._decode_graph_label(node["node_id"])
|
||||
if node_id not in nodes:
|
||||
nodes[node_id] = node
|
||||
|
||||
if result["relationships"]:
|
||||
for edge in result["relationships"]:
|
||||
src_id = self._decode_graph_label(edge["start_id"])
|
||||
tgt_id = self._decode_graph_label(edge["end_id"])
|
||||
edges.append((src_id, tgt_id))
|
||||
for edge in result["relationships"]: # src --DIRECTED--> target
|
||||
src_id = self._decode_graph_label(edge[0]["node_id"])
|
||||
tgt_id = self._decode_graph_label(edge[2]["node_id"])
|
||||
id = src_id + "," + tgt_id
|
||||
if id in unique_edge_ids:
|
||||
continue
|
||||
else:
|
||||
unique_edge_ids.add(id)
|
||||
edges.append(
|
||||
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
|
||||
)
|
||||
|
||||
kg = KnowledgeGraph(
|
||||
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
|
||||
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
|
||||
nodes=[
|
||||
KnowledgeGraphNode(
|
||||
id=node_id, labels=[node_id], properties=nodes[node_id]
|
||||
)
|
||||
for node_id in nodes
|
||||
],
|
||||
edges=[
|
||||
KnowledgeGraphEdge(
|
||||
id=id, type="DIRECTED", source=src, target=tgt, properties=props
|
||||
)
|
||||
for id, src, tgt, props in edges
|
||||
],
|
||||
)
|
||||
|
||||
return kg
|
||||
|
@@ -233,3 +233,45 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
logger.debug(f"No relations found for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
try:
|
||||
# Use scroll method to find records with IDs starting with the prefix
|
||||
results = self._client.scroll(
|
||||
collection_name=self.namespace,
|
||||
scroll_filter=models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="id", match=models.MatchText(text=prefix, prefix=True)
|
||||
)
|
||||
]
|
||||
),
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
limit=1000, # Adjust as needed for your use case
|
||||
)
|
||||
|
||||
# Extract matching points
|
||||
matching_records = results[0]
|
||||
|
||||
# Format the results to match expected return format
|
||||
formatted_results = [
|
||||
{**point.payload, "id": point.id} for point in matching_records
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Found {len(formatted_results)} records with prefix '{prefix}'"
|
||||
)
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
@@ -414,6 +414,55 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
# Ti handles persistence automatically
|
||||
pass
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to search for in record IDs
|
||||
|
||||
Returns:
|
||||
List of records with matching ID prefixes
|
||||
"""
|
||||
# Determine which table to query based on namespace
|
||||
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
|
||||
sql_template = """
|
||||
SELECT entity_id as id, name as entity_name, entity_type, description, content
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
"""
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
|
||||
sql_template = """
|
||||
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
|
||||
keywords, description, content
|
||||
FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
"""
|
||||
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
|
||||
sql_template = """
|
||||
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
"""
|
||||
else:
|
||||
logger.warning(
|
||||
f"Namespace {self.namespace} not supported for prefix search"
|
||||
)
|
||||
return []
|
||||
|
||||
# Add prefix pattern parameter with % for SQL LIKE
|
||||
prefix_pattern = f"{prefix}%"
|
||||
params = {"prefix_pattern": prefix_pattern, "workspace": self.db.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(sql_template, params=params, multirows=True)
|
||||
logger.debug(
|
||||
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
|
||||
)
|
||||
return results if results else []
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -968,4 +1017,20 @@ SQL_TEMPLATES = {
|
||||
WHERE (source_name = :source AND target_name = :target)
|
||||
AND workspace = :workspace
|
||||
""",
|
||||
# Search by prefix SQL templates
|
||||
"search_entity_by_prefix": """
|
||||
SELECT entity_id as id, name as entity_name, entity_type, description, content
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
""",
|
||||
"search_relationship_by_prefix": """
|
||||
SELECT relation_id as id, source_name as src_id, target_name as tgt_id, keywords, description, content
|
||||
FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
""",
|
||||
"search_chunk_by_prefix": """
|
||||
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
""",
|
||||
}
|
||||
|
@@ -504,11 +504,39 @@ class LightRAG:
|
||||
return text
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
) -> KnowledgeGraph:
|
||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||
node_label=node_label, max_depth=max_depth
|
||||
)
|
||||
"""Get knowledge graph for a given label
|
||||
|
||||
Args:
|
||||
node_label (str): Label to get knowledge graph for
|
||||
max_depth (int): Maximum depth of graph
|
||||
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
|
||||
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Knowledge graph containing nodes and edges
|
||||
"""
|
||||
# get params supported by get_knowledge_graph of specified storage
|
||||
import inspect
|
||||
|
||||
storage_params = inspect.signature(
|
||||
self.chunk_entity_relation_graph.get_knowledge_graph
|
||||
).parameters
|
||||
|
||||
kwargs = {"node_label": node_label, "max_depth": max_depth}
|
||||
|
||||
if "min_degree" in storage_params and min_degree > 0:
|
||||
kwargs["min_degree"] = min_degree
|
||||
|
||||
if "inclusive" in storage_params:
|
||||
kwargs["inclusive"] = inclusive
|
||||
|
||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
|
||||
|
||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||
import_path = STORAGES[storage_name]
|
||||
@@ -2016,6 +2044,9 @@ class LightRAG:
|
||||
# Delete old entity record from vector database
|
||||
old_entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
await self.entities_vdb.delete([old_entity_id])
|
||||
logger.info(
|
||||
f"Deleted old entity '{entity_name}' and its vector embedding from database"
|
||||
)
|
||||
|
||||
# Update relationship vector representations
|
||||
for src, tgt, edge_data in relations_to_update:
|
||||
@@ -2143,6 +2174,15 @@ class LightRAG:
|
||||
f"Relation from '{source_entity}' to '{target_entity}' does not exist"
|
||||
)
|
||||
|
||||
# Important: First delete the old relation record from the vector database
|
||||
old_relation_id = compute_mdhash_id(
|
||||
source_entity + target_entity, prefix="rel-"
|
||||
)
|
||||
await self.relationships_vdb.delete([old_relation_id])
|
||||
logger.info(
|
||||
f"Deleted old relation record from vector database for relation {source_entity} -> {target_entity}"
|
||||
)
|
||||
|
||||
# 2. Update relation information in the graph
|
||||
new_edge_data = {**edge_data, **updated_data}
|
||||
await self.chunk_entity_relation_graph.upsert_edge(
|
||||
@@ -2641,12 +2681,29 @@ class LightRAG:
|
||||
|
||||
# 9. Delete source entities
|
||||
for entity_name in source_entities:
|
||||
# Delete entity node
|
||||
# Delete entity node from knowledge graph
|
||||
await self.chunk_entity_relation_graph.delete_node(entity_name)
|
||||
# Delete record from vector database
|
||||
|
||||
# Delete entity record from vector database
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
await self.entities_vdb.delete([entity_id])
|
||||
logger.info(f"Deleted source entity '{entity_name}'")
|
||||
|
||||
# Also ensure any relationships specific to this entity are deleted from vector DB
|
||||
# This is a safety check, as these should have been transformed to the target entity already
|
||||
entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
|
||||
relations_with_entity = await self.relationships_vdb.search_by_prefix(
|
||||
entity_relation_prefix
|
||||
)
|
||||
if relations_with_entity:
|
||||
relation_ids = [r["id"] for r in relations_with_entity]
|
||||
await self.relationships_vdb.delete(relation_ids)
|
||||
logger.info(
|
||||
f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Deleted source entity '{entity_name}' and its vector embedding from database"
|
||||
)
|
||||
|
||||
# 10. Save changes
|
||||
await self._merge_entities_done()
|
||||
|
@@ -1152,7 +1152,8 @@ async def _get_node_data(
|
||||
"entity",
|
||||
"type",
|
||||
"description",
|
||||
"rank" "created_at",
|
||||
"rank",
|
||||
"created_at",
|
||||
]
|
||||
]
|
||||
for i, n in enumerate(node_datas):
|
||||
|
@@ -161,8 +161,12 @@ axiosInstance.interceptors.response.use(
|
||||
)
|
||||
|
||||
// API methods
|
||||
export const queryGraphs = async (label: string, maxDepth: number): Promise<LightragGraphType> => {
|
||||
const response = await axiosInstance.get(`/graphs?label=${label}&max_depth=${maxDepth}`)
|
||||
export const queryGraphs = async (
|
||||
label: string,
|
||||
maxDepth: number,
|
||||
minDegree: number
|
||||
): Promise<LightragGraphType> => {
|
||||
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
|
||||
return response.data
|
||||
}
|
||||
|
||||
|
@@ -40,18 +40,21 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
|
||||
const focusedEdge = useGraphStore.use.focusedEdge()
|
||||
|
||||
/**
|
||||
* When component mount
|
||||
* => load the graph
|
||||
* When component mount or maxIterations changes
|
||||
* => load the graph and apply layout
|
||||
*/
|
||||
useEffect(() => {
|
||||
// Create & load the graph
|
||||
const graph = lightrageGraph()
|
||||
loadGraph(graph)
|
||||
if (!(graph as any).__force_applied) {
|
||||
assignLayout()
|
||||
Object.assign(graph, { __force_applied: true })
|
||||
}
|
||||
}, [assignLayout, loadGraph, lightrageGraph, maxIterations])
|
||||
|
||||
/**
|
||||
* When component mount
|
||||
* => register events
|
||||
*/
|
||||
useEffect(() => {
|
||||
const { setFocusedNode, setSelectedNode, setFocusedEdge, setSelectedEdge, clearSelection } =
|
||||
useGraphStore.getState()
|
||||
|
||||
@@ -87,7 +90,7 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
|
||||
},
|
||||
clickStage: () => clearSelection()
|
||||
})
|
||||
}, [assignLayout, loadGraph, registerEvents, lightrageGraph])
|
||||
}, [registerEvents])
|
||||
|
||||
/**
|
||||
* When component mount or hovered node change
|
||||
|
@@ -90,9 +90,12 @@ const LabeledNumberInput = ({
|
||||
{label}
|
||||
</label>
|
||||
<Input
|
||||
value={currentValue || ''}
|
||||
type="number"
|
||||
value={currentValue === null ? '' : currentValue}
|
||||
onChange={onValueChange}
|
||||
className="h-6 w-full min-w-0"
|
||||
className="h-6 w-full min-w-0 pr-1"
|
||||
min={min}
|
||||
max={max}
|
||||
onBlur={onBlur}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
@@ -119,6 +122,7 @@ export default function Settings() {
|
||||
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
||||
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
||||
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||
const graphMinDegree = useSettingsStore.use.graphMinDegree()
|
||||
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
||||
|
||||
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
||||
@@ -177,6 +181,11 @@ export default function Settings() {
|
||||
useSettingsStore.setState({ graphQueryMaxDepth: depth })
|
||||
}, [])
|
||||
|
||||
const setGraphMinDegree = useCallback((degree: number) => {
|
||||
if (degree < 0) return
|
||||
useSettingsStore.setState({ graphMinDegree: degree })
|
||||
}, [])
|
||||
|
||||
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
||||
if (iterations < 1) return
|
||||
useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
|
||||
@@ -266,6 +275,12 @@ export default function Settings() {
|
||||
value={graphQueryMaxDepth}
|
||||
onEditFinished={setGraphQueryMaxDepth}
|
||||
/>
|
||||
<LabeledNumberInput
|
||||
label="Minimum Degree"
|
||||
min={0}
|
||||
value={graphMinDegree}
|
||||
onEditFinished={setGraphMinDegree}
|
||||
/>
|
||||
<LabeledNumberInput
|
||||
label="Max Layout Iterations"
|
||||
min={1}
|
||||
|
@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
|
||||
<input
|
||||
type={type}
|
||||
className={cn(
|
||||
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm',
|
||||
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
|
@@ -50,11 +50,11 @@ export type NodeType = {
|
||||
}
|
||||
export type EdgeType = { label: string }
|
||||
|
||||
const fetchGraph = async (label: string, maxDepth: number) => {
|
||||
const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
|
||||
let rawData: any = null
|
||||
|
||||
try {
|
||||
rawData = await queryGraphs(label, maxDepth)
|
||||
rawData = await queryGraphs(label, maxDepth, minDegree)
|
||||
} catch (e) {
|
||||
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
|
||||
return null
|
||||
@@ -161,13 +161,14 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
|
||||
return graph
|
||||
}
|
||||
|
||||
const lastQueryLabel = { label: '', maxQueryDepth: 0 }
|
||||
const lastQueryLabel = { label: '', maxQueryDepth: 0, minDegree: 0 }
|
||||
|
||||
const useLightrangeGraph = () => {
|
||||
const queryLabel = useSettingsStore.use.queryLabel()
|
||||
const rawGraph = useGraphStore.use.rawGraph()
|
||||
const sigmaGraph = useGraphStore.use.sigmaGraph()
|
||||
const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||
const minDegree = useSettingsStore.use.graphMinDegree()
|
||||
|
||||
const getNode = useCallback(
|
||||
(nodeId: string) => {
|
||||
@@ -185,13 +186,16 @@ const useLightrangeGraph = () => {
|
||||
|
||||
useEffect(() => {
|
||||
if (queryLabel) {
|
||||
if (lastQueryLabel.label !== queryLabel || lastQueryLabel.maxQueryDepth !== maxQueryDepth) {
|
||||
if (lastQueryLabel.label !== queryLabel ||
|
||||
lastQueryLabel.maxQueryDepth !== maxQueryDepth ||
|
||||
lastQueryLabel.minDegree !== minDegree) {
|
||||
lastQueryLabel.label = queryLabel
|
||||
lastQueryLabel.maxQueryDepth = maxQueryDepth
|
||||
lastQueryLabel.minDegree = minDegree
|
||||
|
||||
const state = useGraphStore.getState()
|
||||
state.reset()
|
||||
fetchGraph(queryLabel, maxQueryDepth).then((data) => {
|
||||
fetchGraph(queryLabel, maxQueryDepth, minDegree).then((data) => {
|
||||
// console.debug('Query label: ' + queryLabel)
|
||||
state.setSigmaGraph(createSigmaGraph(data))
|
||||
data?.buildDynamicMap()
|
||||
@@ -203,7 +207,7 @@ const useLightrangeGraph = () => {
|
||||
state.reset()
|
||||
state.setSigmaGraph(new DirectedGraph())
|
||||
}
|
||||
}, [queryLabel, maxQueryDepth])
|
||||
}, [queryLabel, maxQueryDepth, minDegree])
|
||||
|
||||
const lightrageGraph = useCallback(() => {
|
||||
if (sigmaGraph) {
|
||||
|
@@ -22,6 +22,9 @@ interface SettingsState {
|
||||
graphQueryMaxDepth: number
|
||||
setGraphQueryMaxDepth: (depth: number) => void
|
||||
|
||||
graphMinDegree: number
|
||||
setGraphMinDegree: (degree: number) => void
|
||||
|
||||
graphLayoutMaxIterations: number
|
||||
setGraphLayoutMaxIterations: (iterations: number) => void
|
||||
|
||||
@@ -66,6 +69,7 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
enableEdgeEvents: false,
|
||||
|
||||
graphQueryMaxDepth: 3,
|
||||
graphMinDegree: 0,
|
||||
graphLayoutMaxIterations: 10,
|
||||
|
||||
queryLabel: defaultQueryLabel,
|
||||
@@ -107,6 +111,8 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
|
||||
setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
|
||||
|
||||
setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
|
||||
|
||||
setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
|
||||
|
||||
setApiKey: (apiKey: string | null) => set({ apiKey }),
|
||||
|
10
lightrag_webui/src/vite-env.d.ts
vendored
10
lightrag_webui/src/vite-env.d.ts
vendored
@@ -1 +1,11 @@
|
||||
/// <reference types="vite/client" />
|
||||
|
||||
interface ImportMetaEnv {
|
||||
readonly VITE_API_PROXY: string
|
||||
readonly VITE_API_ENDPOINTS: string
|
||||
readonly VITE_BACKEND_URL: string
|
||||
}
|
||||
|
||||
interface ImportMeta {
|
||||
readonly env: ImportMetaEnv
|
||||
}
|
||||
|
@@ -26,5 +26,5 @@
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["src"]
|
||||
"include": ["src", "vite.config.ts"]
|
||||
}
|
||||
|
@@ -14,6 +14,21 @@ export default defineConfig({
|
||||
},
|
||||
base: './',
|
||||
build: {
|
||||
outDir: path.resolve(__dirname, '../lightrag/api/webui')
|
||||
outDir: path.resolve(__dirname, '../lightrag/api/webui'),
|
||||
emptyOutDir: true
|
||||
},
|
||||
server: {
|
||||
proxy: import.meta.env.VITE_API_PROXY === 'true' && import.meta.env.VITE_API_ENDPOINTS ?
|
||||
Object.fromEntries(
|
||||
import.meta.env.VITE_API_ENDPOINTS.split(',').map(endpoint => [
|
||||
endpoint,
|
||||
{
|
||||
target: import.meta.env.VITE_BACKEND_URL || 'http://localhost:9621',
|
||||
changeOrigin: true,
|
||||
rewrite: endpoint === '/api' ?
|
||||
(path) => path.replace(/^\/api/, '') : undefined
|
||||
}
|
||||
])
|
||||
) : {}
|
||||
}
|
||||
})
|
||||
|
Reference in New Issue
Block a user