diff --git a/.env.example b/.env.example index 82b9ca70..0c61d2e0 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,17 @@ PORT=9621 WORKING_DIR=/app/data/rag_storage INPUT_DIR=/app/data/inputs +# RAG Configuration +MAX_ASYNC=4 +MAX_TOKENS=32768 +EMBEDDING_DIM=1024 +MAX_EMBED_TOKENS=8192 +#HISTORY_TURNS=3 +#CHUNK_SIZE=1200 +#CHUNK_OVERLAP_SIZE=100 +#COSINE_THRESHOLD=0.2 +#TOP_K=50 + # LLM Configuration (Use valid host. For local services, you can use host.docker.internal) # Ollama example LLM_BINDING=ollama @@ -38,15 +49,6 @@ EMBEDDING_MODEL=bge-m3:latest # EMBEDDING_BINDING_HOST=http://host.docker.internal:9600 # EMBEDDING_MODEL=bge-m3:latest -# RAG Configuration -MAX_ASYNC=4 -MAX_TOKENS=32768 -EMBEDDING_DIM=1024 -MAX_EMBED_TOKENS=8192 -#HISTORY_TURNS=3 -#CHUNK_SIZE=1200 -#CHUNK_OVERLAP_SIZE=100 - # Security (empty for no key) LIGHTRAG_API_KEY=your-secure-api-key-here diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 35e4acf7..2ab30d2b 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -207,8 +207,12 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.chunk_size}") ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") ASCIIColors.yellow(f"{args.chunk_overlap_size}") - ASCIIColors.white(" └─ History Turns: ", end="") + ASCIIColors.white(" ├─ History Turns: ", end="") ASCIIColors.yellow(f"{args.history_turns}") + ASCIIColors.white(" ├─ Cosine Threshold: ", end="") + ASCIIColors.yellow(f"{args.cosine_threshold}") + ASCIIColors.white(" └─ Top-K: ", end="") + ASCIIColors.yellow(f"{args.top_k}") # System Configuration ASCIIColors.magenta("\n🛠️ System Configuration:") @@ -484,6 +488,20 @@ def parse_args() -> argparse.Namespace: help="Number of conversation history turns to include (default: from env or 3)", ) + # Search parameters + parser.add_argument( + "--top-k", + type=int, + default=get_env_value("TOP_K", 50, int), + help="Number of most similar results 
to return (default: from env or 50)", + ) + parser.add_argument( + "--cosine-threshold", + type=float, + default=get_env_value("COSINE_THRESHOLD", 0.2, float), + help="Cosine similarity threshold (default: from env or 0.2)", + ) + args = parser.parse_args() return args @@ -846,6 +864,9 @@ def create_app(args): graph_storage=GRAPH_STORAGE, vector_storage=VECTOR_STORAGE, doc_status_storage=DOC_STATUS_STORAGE, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, ) else: rag = LightRAG( @@ -863,6 +884,9 @@ def create_app(args): graph_storage=GRAPH_STORAGE, vector_storage=VECTOR_STORAGE, doc_status_storage=DOC_STATUS_STORAGE, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, ) async def index_file(file_path: Union[str, Path]) -> None: @@ -1052,6 +1076,7 @@ def create_app(args): mode=request.mode, stream=request.stream, only_need_context=request.only_need_context, + top_k=args.top_k, ), ) @@ -1093,6 +1118,7 @@ def create_app(args): mode=request.mode, stream=True, only_need_context=request.only_need_context, + top_k=args.top_k, ), ) @@ -1632,6 +1658,7 @@ def create_app(args): "stream": request.stream, "only_need_context": False, "conversation_history": conversation_history, + "top_k": args.top_k, } if args.history_turns is not None: diff --git a/lightrag/base.py b/lightrag/base.py index 36e70893..e71cac3f 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass, field from typing import ( TypedDict, @@ -32,7 +33,7 @@ class QueryParam: response_type: str = "Multiple Paragraphs" stream: bool = False # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. - top_k: int = 60 + top_k: int = int(os.getenv("TOP_K", "60")) # Number of document chunks to retrieve. # top_n: int = 10 # Number of tokens for the original chunks.
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 328a1242..b6650797 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -73,7 +73,7 @@ from lightrag.base import ( @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = 0.2 + cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): self._client_file_name = os.path.join(