Fix: top_k param handling error, unify top_k and cosine default value.
This commit is contained in:
@@ -32,8 +32,8 @@ MAX_EMBED_TOKENS=8192
|
|||||||
#HISTORY_TURNS=3
|
#HISTORY_TURNS=3
|
||||||
#CHUNK_SIZE=1200
|
#CHUNK_SIZE=1200
|
||||||
#CHUNK_OVERLAP_SIZE=100
|
#CHUNK_OVERLAP_SIZE=100
|
||||||
#COSINE_THRESHOLD=0.4 # 0.2 while not running API server
|
#COSINE_THRESHOLD=0.2
|
||||||
#TOP_K=50 # 60 while not running API server
|
#TOP_K=60
|
||||||
|
|
||||||
### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
|
### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
|
||||||
### Ollama example
|
### Ollama example
|
||||||
|
@@ -103,7 +103,7 @@ After starting the lightrag-server, you can add an Ollama-type connection in the
|
|||||||
|
|
||||||
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
|
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
|
||||||
|
|
||||||
For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM.
|
Default `TOP_K` is set to `60`. Default `COSINE_THRESHOLD` are set to `0.2`.
|
||||||
|
|
||||||
### Environment Variables
|
### Environment Variables
|
||||||
|
|
||||||
|
@@ -530,13 +530,13 @@ def parse_args() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-k",
|
"--top-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=get_env_value("TOP_K", 50, int),
|
default=get_env_value("TOP_K", 60, int),
|
||||||
help="Number of most similar results to return (default: from env or 50)",
|
help="Number of most similar results to return (default: from env or 60)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cosine-threshold",
|
"--cosine-threshold",
|
||||||
type=float,
|
type=float,
|
||||||
default=get_env_value("COSINE_THRESHOLD", 0.4, float),
|
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
||||||
help="Cosine similarity threshold (default: from env or 0.4)",
|
help="Cosine similarity threshold (default: from env or 0.4)",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -669,7 +669,13 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|||||||
return api_key_auth
|
return api_key_auth
|
||||||
|
|
||||||
|
|
||||||
|
# Global configuration
|
||||||
|
global_top_k = 60 # default value
|
||||||
|
|
||||||
def create_app(args):
|
def create_app(args):
|
||||||
|
global global_top_k
|
||||||
|
global_top_k = args.top_k # save top_k from args
|
||||||
|
|
||||||
# Verify that bindings are correctly setup
|
# Verify that bindings are correctly setup
|
||||||
if args.llm_binding not in [
|
if args.llm_binding not in [
|
||||||
"lollms",
|
"lollms",
|
||||||
@@ -1279,7 +1285,7 @@ def create_app(args):
|
|||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
only_need_context=request.only_need_context,
|
only_need_context=request.only_need_context,
|
||||||
top_k=args.top_k,
|
top_k=global_top_k,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1321,7 +1327,7 @@ def create_app(args):
|
|||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
stream=True,
|
stream=True,
|
||||||
only_need_context=request.only_need_context,
|
only_need_context=request.only_need_context,
|
||||||
top_k=args.top_k,
|
top_k=global_top_k,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1611,7 +1617,7 @@ def create_app(args):
|
|||||||
return await rag.get_graps(nodel_label=label, max_depth=100)
|
return await rag.get_graps(nodel_label=label, max_depth=100)
|
||||||
|
|
||||||
# Add Ollama API routes
|
# Add Ollama API routes
|
||||||
ollama_api = OllamaAPI(rag)
|
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||||
app.include_router(ollama_api.router, prefix="/api")
|
app.include_router(ollama_api.router, prefix="/api")
|
||||||
|
|
||||||
@app.get("/documents", dependencies=[Depends(optional_api_key)])
|
@app.get("/documents", dependencies=[Depends(optional_api_key)])
|
||||||
|
@@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
|||||||
|
|
||||||
|
|
||||||
class OllamaAPI:
|
class OllamaAPI:
|
||||||
def __init__(self, rag: LightRAG):
|
def __init__(self, rag: LightRAG, top_k: int = 60):
|
||||||
self.rag = rag
|
self.rag = rag
|
||||||
self.ollama_server_infos = ollama_server_infos
|
self.ollama_server_infos = ollama_server_infos
|
||||||
|
self.top_k = top_k
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
self.setup_routes()
|
self.setup_routes()
|
||||||
|
|
||||||
@@ -381,7 +382,7 @@ class OllamaAPI:
|
|||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
"only_need_context": False,
|
"only_need_context": False,
|
||||||
"conversation_history": conversation_history,
|
"conversation_history": conversation_history,
|
||||||
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
|
"top_k": self.top_k,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
Reference in New Issue
Block a user