Merge remote-tracking branch 'upstream/main'
@@ -13,14 +13,13 @@ import re
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import List, Any, Optional, Dict
from pydantic import BaseModel
from typing import List, Any, Literal, Optional, Dict
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.base import DocStatus, DocProcessingStatus
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
        return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


# LightRAG query mode
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"


class QueryRequest(BaseModel):
    query: str
    query: str = Field(
        min_length=1,
        description="The query text",
    )

    """Specifies the retrieval mode"""
    mode: SearchMode = SearchMode.hybrid
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
        default="hybrid",
        description="Query mode",
    )

    """If True, enables streaming output for real-time responses."""
    stream: Optional[bool] = None
    only_need_context: Optional[bool] = Field(
        default=None,
        description="If True, only returns the retrieved context without generating a response.",
    )

    """If True, only returns the retrieved context without generating a response."""
    only_need_context: Optional[bool] = None
    only_need_prompt: Optional[bool] = Field(
        default=None,
        description="If True, only returns the generated prompt without producing a response.",
    )

    """If True, only returns the generated prompt without producing a response."""
    only_need_prompt: Optional[bool] = None
    response_type: Optional[str] = Field(
        min_length=1,
        default=None,
        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
    )

    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
    response_type: Optional[str] = None
    top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
    )

    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
    top_k: Optional[int] = None
    max_token_for_text_unit: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allowed for each retrieved text chunk.",
    )

    """Maximum number of tokens allowed for each retrieved text chunk."""
    max_token_for_text_unit: Optional[int] = None
    max_token_for_global_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
    )

    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
    max_token_for_global_context: Optional[int] = None
    max_token_for_local_context: Optional[int] = Field(
        gt=1,
        default=None,
        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
    )

    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
    max_token_for_local_context: Optional[int] = None
    hl_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of high-level keywords to prioritize in retrieval.",
    )

    """List of high-level keywords to prioritize in retrieval."""
    hl_keywords: Optional[List[str]] = None
    ll_keywords: Optional[List[str]] = Field(
        default=None,
        description="List of low-level keywords to refine retrieval focus.",
    )

    """List of low-level keywords to refine retrieval focus."""
    ll_keywords: Optional[List[str]] = None
    conversation_history: Optional[List[dict[str, Any]]] = Field(
        default=None,
        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
    )

    """Stores past conversation history to maintain context.
    Format: [{"role": "user/assistant", "content": "message"}].
    """
    conversation_history: Optional[List[dict[str, Any]]] = None
    history_turns: Optional[int] = Field(
        ge=0,
        default=None,
        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
    )

    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
    history_turns: Optional[int] = None
    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        return query.strip()

    @field_validator("hl_keywords", mode="after")
    @classmethod
    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
        if hl_keywords is None:
            return None
        return [keyword.strip() for keyword in hl_keywords]

    @field_validator("ll_keywords", mode="after")
    @classmethod
    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
        if ll_keywords is None:
            return None
        return [keyword.strip() for keyword in ll_keywords]

    @field_validator("conversation_history", mode="after")
    @classmethod
    def conversation_history_role_check(
        cls, conversation_history: List[dict[str, Any]] | None
    ) -> List[dict[str, Any]] | None:
        if conversation_history is None:
            return None
        for msg in conversation_history:
            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
                raise ValueError(
                    "Each message must have a 'role' key with value 'user' or 'assistant'."
                )
        return conversation_history

    def to_query_params(self, is_stream: bool) -> QueryParam:
        """Converts a QueryRequest instance into a QueryParam instance."""
        # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
        request_data = self.model_dump(exclude_none=True, exclude={"query"})

        # Ensure `mode` and `stream` are set explicitly
        param = QueryParam(**request_data)
        param.stream = is_stream
        return param
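
Note: a minimal usage sketch of the QueryRequest-to-QueryParam conversion shown above, assuming QueryRequest and QueryParam are in scope as in the API server module; the values are illustrative, not part of the commit.

# Build a request the way FastAPI would after validation (illustrative values).
request = QueryRequest(
    query="  What is LightRAG?  ",  # stripped to "What is LightRAG?" by query_strip_after
    mode="hybrid",                   # anything outside the Literal values fails validation
    top_k=10,
    conversation_history=[{"role": "user", "content": "hi"}],
)

param = request.to_query_params(is_stream=False)
# model_dump(exclude_none=True, exclude={"query"}) drops unset fields, so QueryParam
# keeps its own defaults for them; stream is always set explicitly from is_stream.
assert param.mode == "hybrid" and param.top_k == 10 and param.stream is False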


class QueryResponse(BaseModel):
    response: str
    response: str = Field(
        description="The generated response",
    )


class InsertTextRequest(BaseModel):
    text: str
    text: str = Field(
        min_length=1,
        description="The text to insert",
    )

    @field_validator("text", mode="after")
    @classmethod
    def strip_after(cls, text: str) -> str:
        return text.strip()


class InsertTextsRequest(BaseModel):
    texts: list[str] = Field(
        min_length=1,
        description="The texts to insert",
    )

    @field_validator("texts", mode="after")
    @classmethod
    def strip_after(cls, texts: list[str]) -> list[str]:
        return [text.strip() for text in texts]


class InsertResponse(BaseModel):
    status: str
    message: str
    status: str = Field(description="Status of the operation")
    message: str = Field(description="Message describing the operation result")
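
Note: a short sketch of what the constraints on InsertTextsRequest above mean for callers; it assumes the model is in scope and uses pydantic's ValidationError, everything else is illustrative.

from pydantic import ValidationError

ok = InsertTextsRequest(texts=["  first doc  ", "second doc"])
assert ok.texts == ["first doc", "second doc"]  # each entry is stripped after validation

try:
    InsertTextsRequest(texts=[])  # min_length=1 applies to the list, so an empty payload is rejected
except ValidationError as exc:
    print(exc.error_count(), "validation error")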


class DocStatusResponse(BaseModel):
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
    statuses: Dict[DocStatus, List[DocStatusResponse]] = {}


def QueryRequestToQueryParams(request: QueryRequest):
    param = QueryParam(mode=request.mode, stream=request.stream)
    if request.only_need_context is not None:
        param.only_need_context = request.only_need_context
    if request.only_need_prompt is not None:
        param.only_need_prompt = request.only_need_prompt
    if request.response_type is not None:
        param.response_type = request.response_type
    if request.top_k is not None:
        param.top_k = request.top_k
    if request.max_token_for_text_unit is not None:
        param.max_token_for_text_unit = request.max_token_for_text_unit
    if request.max_token_for_global_context is not None:
        param.max_token_for_global_context = request.max_token_for_global_context
    if request.max_token_for_local_context is not None:
        param.max_token_for_local_context = request.max_token_for_local_context
    if request.hl_keywords is not None:
        param.hl_keywords = request.hl_keywords
    if request.ll_keywords is not None:
        param.ll_keywords = request.ll_keywords
    if request.conversation_history is not None:
        param.conversation_history = request.conversation_history
    if request.history_turns is not None:
        param.history_turns = request.history_turns
    return param


def get_api_key_dependency(api_key: Optional[str]):
    if not api_key:
        # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
            logging.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/texts",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_texts(
        request: InsertTextsRequest, background_tasks: BackgroundTasks
    ):
        """
        Insert texts into the Retrieval-Augmented Generation (RAG) system.

        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.

        Args:
            request (InsertTextsRequest): The request body containing the text to be inserted.
            background_tasks: FastAPI BackgroundTasks for async processing

        Returns:
            InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
        """
        try:
            background_tasks.add_task(pipeline_index_texts, request.texts)
            return InsertResponse(
                status="success",
                message="Text successfully received. Processing will continue in background.",
            )
        except Exception as e:
            logging.error(f"Error /documents/text: {str(e)}")
            logging.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
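
Note: a hedged client-side sketch of the /documents/texts route added above; the base URL, port, and API-key header are placeholders for whatever the server is actually configured with.

import requests

resp = requests.post(
    "http://localhost:9621/documents/texts",      # hypothetical local deployment
    json={"texts": ["LightRAG ingests plain text.", "A second document."]},
    headers={"X-API-Key": "<your-key-if-configured>"},  # only needed if a key is set
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected shape per InsertResponse: {"status": "success", "message": "..."}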

    @app.post(
        "/documents/file",
        response_model=InsertResponse,
@@ -1569,7 +1656,7 @@ def create_app(args):
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/batch",
        "/documents/file_batch",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
@@ -1673,20 +1760,14 @@ def create_app(args):
        """
        try:
            response = await rag.aquery(
                request.query, param=QueryRequestToQueryParams(request)
                request.query, param=request.to_query_params(False)
            )

            # If response is a string (e.g. cache hit), return directly
            if isinstance(response, str):
                return QueryResponse(response=response)

            # If it's an async generator, decide whether to stream based on stream parameter
            if request.stream or hasattr(response, "__aiter__"):
                result = ""
                async for chunk in response:
                    result += chunk
                return QueryResponse(response=result)
            elif isinstance(response, dict):
            if isinstance(response, dict):
                result = json.dumps(response, indent=2)
                return QueryResponse(response=result)
            else:
@@ -1708,11 +1789,8 @@ def create_app(args):
            StreamingResponse: A streaming response containing the RAG query results.
        """
        try:
            params = QueryRequestToQueryParams(request)

            params.stream = True
            response = await rag.aquery( # Use aquery instead of query, and add await
                request.query, param=params
            response = await rag.aquery(
                request.query, param=request.to_query_params(True)
            )

            from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@ def create_app(args):
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Content-Type": "application/x-ndjson",
                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                },
            )
        except Exception as e:
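
Note: the streaming route returns application/x-ndjson with proxy buffering disabled, so a client can read it line by line. A hedged consumer sketch follows; the /query/stream path and the "response" key per chunk are assumptions about the surrounding API, not shown in this diff.

import json
import requests

with requests.post(
    "http://localhost:9621/query/stream",          # assumed route and port
    json={"query": "Summarize the corpus", "mode": "mix"},
    stream=True,
    timeout=120,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:                                    # each non-empty NDJSON line is one JSON object
            chunk = json.loads(line)
            print(chunk.get("response", ""), end="", flush=True)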

File diff suppressed because one or more lines are too long
lightrag/api/webui/assets/index-Cq9iD15S.css (new file, 1 line): file diff suppressed because one or more lines are too long
lightrag/api/webui/assets/index-gr1CNi7P.js (new file, 1098 lines): file diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@
    <link rel="icon" type="image/svg+xml" href="./vite.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Lightrag</title>
    <script type="module" crossorigin src="./assets/index-BMB0OroL.js"></script>
    <link rel="stylesheet" crossorigin href="./assets/index-CLgSwrjG.css">
    <script type="module" crossorigin src="./assets/index-gr1CNi7P.js"></script>
    <link rel="stylesheet" crossorigin href="./assets/index-Cq9iD15S.css">
  </head>
  <body>
    <div id="root"></div>

@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import StrEnum
from enum import Enum
import os
from dotenv import load_dotenv
from dataclasses import dataclass, field
@@ -205,7 +205,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
    """Retrieve a subgraph of the knowledge graph starting from a given node."""


class DocStatus(StrEnum):
class DocStatus(str, Enum):
    """Document processing status"""

    PENDING = "pending"
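
Note: enum.StrEnum only exists on Python 3.11+, so switching to the (str, Enum) mixin keeps the same string-like behaviour on older interpreters. A quick sketch (not from the commit) of why later hunks can drop the explicit .value:

from enum import Enum

class DocStatus(str, Enum):  # mirrors the new definition above
    PENDING = "pending"

# Members of a str-mixin enum compare and serialize as plain strings,
# so storing DocStatus.PENDING is equivalent to storing "pending".
assert DocStatus.PENDING == "pending"
assert isinstance(DocStatus.PENDING, str)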

@@ -296,14 +296,16 @@ class Neo4JStorage(BaseGraphStorage):

            result = await session.run(query)
            record = await result.single()
            if record and "edge_properties" in record:
            if record:
                try:
                    result = dict(record["edge_properties"])
                    logger.info(f"Result: {result}")
                    # Ensure required keys exist with defaults
                    required_keys = {
                        "weight": 0.0,
                        "source_id": None,
                        "target_id": None,
                        "description": None,
                        "keywords": None,
                    }
                    for key, default_value in required_keys.items():
                        if key not in result:
@@ -323,20 +325,35 @@ class Neo4JStorage(BaseGraphStorage):
                        f"and {entity_name_label_target}: {str(e)}"
                    )
                    # Return default edge properties on error
                    return {"weight": 0.0, "source_id": None, "target_id": None}
                    return {
                        "weight": 0.0,
                        "description": None,
                        "keywords": None,
                        "source_id": None,
                    }

            logger.debug(
                f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
            )
            # Return default edge properties when no edge found
            return {"weight": 0.0, "source_id": None, "target_id": None}
            return {
                "weight": 0.0,
                "description": None,
                "keywords": None,
                "source_id": None,
            }

        except Exception as e:
            logger.error(
                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
            )
            # Return default edge properties on error
            return {"weight": 0.0, "source_id": None, "target_id": None}
            return {
                "weight": 0.0,
                "description": None,
                "keywords": None,
                "source_id": None,
            }

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        node_label = source_node_id.strip('"')
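
Note: the widened fallback dict matters because downstream code reads description and keywords from every edge record. A hedged illustration follows; the default keys are taken from the hunks above, while the consumer function is hypothetical.

# Default returned when no edge is found or the lookup fails (per the hunks above).
DEFAULT_EDGE = {"weight": 0.0, "description": None, "keywords": None, "source_id": None}

def describe_edge(edge: dict) -> str:
    # Hypothetical consumer: with the old three-key default this raised KeyError on
    # "description"/"keywords"; with the new default it degrades gracefully instead.
    return f'{edge["description"] or "n/a"} (keywords: {edge["keywords"] or "n/a"}, weight: {edge["weight"]})'

print(describe_edge(DEFAULT_EDGE))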

@@ -674,7 +674,7 @@ class LightRAG:
                "content": content,
                "content_summary": self._get_content_summary(content),
                "content_length": len(content),
                "status": DocStatus.PENDING.value,
                "status": DocStatus.PENDING,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat(),
            }
@@ -745,7 +745,7 @@ class LightRAG:
                await self.doc_status.upsert(
                    {
                        doc_status_id: {
                            "status": DocStatus.PROCESSING.value,
                            "status": DocStatus.PROCESSING,
                            "updated_at": datetime.now().isoformat(),
                            "content": status_doc.content,
                            "content_summary": status_doc.content_summary,
@@ -782,7 +782,7 @@ class LightRAG:
                await self.doc_status.upsert(
                    {
                        doc_status_id: {
                            "status": DocStatus.PROCESSED.value,
                            "status": DocStatus.PROCESSED,
                            "chunks_count": len(chunks),
                            "content": status_doc.content,
                            "content_summary": status_doc.content_summary,
@@ -799,7 +799,7 @@ class LightRAG:
                await self.doc_status.upsert(
                    {
                        doc_status_id: {
                            "status": DocStatus.FAILED.value,
                            "status": DocStatus.FAILED,
                            "error": str(e),
                            "content": status_doc.content,
                            "content_summary": status_doc.content_summary,
@@ -984,7 +984,10 @@ class LightRAG:
        await self._insert_done()

    def query(
        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
        self,
        query: str,
        param: QueryParam = QueryParam(),
        system_prompt: str | None = None,
    ) -> str | Iterator[str]:
        """
        Perform a sync query.
@@ -999,13 +1002,13 @@ class LightRAG:
        """
        loop = always_get_an_event_loop()

        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore
        return loop.run_until_complete(self.aquery(query, param, system_prompt))  # type: ignore

    async def aquery(
        self,
        query: str,
        param: QueryParam = QueryParam(),
        prompt: str | None = None,
        system_prompt: str | None = None,
    ) -> str | AsyncIterator[str]:
        """
        Perform a async query.
@@ -1037,7 +1040,7 @@ class LightRAG:
                    global_config=asdict(self),
                    embedding_func=self.embedding_func,
                ),
                prompt=prompt,
                system_prompt=system_prompt,
            )
        elif param.mode == "naive":
            response = await naive_query(
@@ -1056,6 +1059,7 @@ class LightRAG:
                    global_config=asdict(self),
                    embedding_func=self.embedding_func,
                ),
                system_prompt=system_prompt,
            )
        elif param.mode == "mix":
            response = await mix_kg_vector_query(
@@ -1077,6 +1081,7 @@ class LightRAG:
                    global_config=asdict(self),
                    embedding_func=self.embedding_func,
                ),
                system_prompt=system_prompt,
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
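
Note: the keyword argument is renamed from prompt to system_prompt and is now forwarded to the naive and mix handlers as well. A hedged usage sketch, assuming an already-configured LightRAG instance; for local/global/hybrid the string is still passed through .format(context_data=..., response_type=...), so a custom template will usually want a {context_data} placeholder.

import asyncio
from lightrag import LightRAG, QueryParam

async def main(rag: LightRAG) -> None:
    answer = await rag.aquery(
        "What entities are mentioned most often?",
        param=QueryParam(mode="hybrid"),
        # Hypothetical template; {context_data} is filled by kg_query's .format call.
        system_prompt="You are a terse analyst. Use only this context:\n{context_data}",
    )
    print(answer)

# asyncio.run(main(rag))  # `rag` must be a configured LightRAG instance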

@@ -613,7 +613,7 @@ async def kg_query(
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    prompt: str | None = None,
    system_prompt: str | None = None,
) -> str:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
@@ -677,7 +677,7 @@ async def kg_query(
        query_param.conversation_history, query_param.history_turns
    )

    sys_prompt_temp = prompt if prompt else PROMPTS["rag_response"]
    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context,
        response_type=query_param.response_type,
@@ -828,6 +828,7 @@ async def mix_kg_vector_query(
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -962,15 +963,19 @@ async def mix_kg_vector_query(
        return {"kg_context": kg_context, "vector_context": vector_context}

    # 5. Construct hybrid prompt
    sys_prompt = PROMPTS["mix_rag_response"].format(
        kg_context=kg_context
        if kg_context
        else "No relevant knowledge graph information found",
        vector_context=vector_context
        if vector_context
        else "No relevant text information found",
        response_type=query_param.response_type,
        history=history_context,
    sys_prompt = (
        system_prompt
        if system_prompt
        else PROMPTS["mix_rag_response"].format(
            kg_context=kg_context
            if kg_context
            else "No relevant knowledge graph information found",
            vector_context=vector_context
            if vector_context
            else "No relevant text information found",
            response_type=query_param.response_type,
            history=history_context,
        )
    )

    if query_param.only_need_prompt:
@@ -1599,6 +1604,7 @@ async def naive_query(
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
@@ -1651,7 +1657,7 @@ async def naive_query(
        query_param.conversation_history, query_param.history_turns
    )

    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section,
        response_type=query_param.response_type,
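
Note: the three handlers treat the override differently, which matters when writing a custom template: kg_query and naive_query still run the caller's string through .format(...), while mix_kg_vector_query uses it verbatim. A hedged illustration with hypothetical templates:

# local/global/hybrid (kg_query): formatted with the retrieved context
kg_template = "Use only this context:\n{context_data}\n\nRespond as {response_type}."

# naive mode: the placeholder is content_data instead
naive_template = "Chunks:\n{content_data}\n\nRespond as {response_type}."

# mix mode: the string is used as-is, so placeholders would be left unexpanded
mix_instruction = "Answer concisely and prefer knowledge-graph facts over raw chunks."

print(kg_template.format(context_data="<context rows>", response_type="Single Paragraph"))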