Add user prompt support for Ollama API

yangdx
2025-05-09 11:37:43 +08:00
parent 4e1caf1e40
commit fb4f12ba8e


@@ -101,10 +101,31 @@ def estimate_tokens(text: str) -> int:
     return len(tokens)
 
 
-def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
+def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
     """Parse query prefix to determine search mode
-    Returns tuple of (cleaned_query, search_mode, only_need_context)
+    Returns tuple of (cleaned_query, search_mode, only_need_context, user_prompt)
+
+    Examples:
+    - "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
+    - "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.hybrid, False, "use mermaid format for diagrams")
+    - "/local query string" -> (cleaned_query, SearchMode.local, False, None)
     """
+    # Initialize user_prompt as None
+    user_prompt = None
+
+    # First check if there's a bracket format for user prompt
+    bracket_pattern = r"^/([a-z]*)\[(.*?)\](.*)"
+    bracket_match = re.match(bracket_pattern, query)
+
+    if bracket_match:
+        mode_prefix = bracket_match.group(1)
+        user_prompt = bracket_match.group(2)
+        remaining_query = bracket_match.group(3).lstrip()
+        # Reconstruct query, removing the bracket part
+        query = f"/{mode_prefix} {remaining_query}".strip()
+
+    # Unified handling of mode and only_need_context determination
     mode_map = {
         "/local ": (SearchMode.local, False),
         "/global ": (
@@ -128,11 +149,11 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
     for prefix, (mode, only_need_context) in mode_map.items():
         if query.startswith(prefix):
-            # After removing prefix an leading spaces
+            # After removing prefix and leading spaces
             cleaned_query = query[len(prefix) :].lstrip()
-            return cleaned_query, mode, only_need_context
+            return cleaned_query, mode, only_need_context, user_prompt
 
-    return query, SearchMode.hybrid, False
+    return query, SearchMode.hybrid, False, user_prompt
 
 
 class OllamaAPI:
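
For reference, here is a self-contained sketch of the bracket-parsing step added above. split_user_prompt is a hypothetical name used for illustration only; it reimplements just the regex extraction, while the real parse_query_mode goes on to resolve the mode via mode_map:

import re
from typing import Optional

def split_user_prompt(query: str) -> tuple[str, Optional[str]]:
    """Reduced stand-in for the bracket extraction in parse_query_mode."""
    match = re.match(r"^/([a-z]*)\[(.*?)\](.*)", query)
    if not match:
        return query, None
    mode_prefix, user_prompt, rest = match.groups()
    # Rebuild the query without the bracket part, as in the diff
    return f"/{mode_prefix} {rest.lstrip()}".strip(), user_prompt

print(split_user_prompt("/local[use mermaid format for diagrams] query string"))
# ('/local query string', 'use mermaid format for diagrams')
print(split_user_prompt("/[use mermaid format for diagrams] query string"))
# ('/ query string', 'use mermaid format for diagrams')
print(split_user_prompt("plain query"))
# ('plain query', None)

Note that the bare-slash form rebuilds to "/ query string", which matches no mode_map prefix, so the final return falls through to SearchMode.hybrid, matching the second docstring example.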
@@ -362,7 +383,9 @@ class OllamaAPI:
             ]
 
             # Check for query prefix
-            cleaned_query, mode, only_need_context = parse_query_mode(query)
+            cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
+                query
+            )
 
             start_time = time.time_ns()
             prompt_tokens = estimate_tokens(cleaned_query)
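
estimate_tokens appears in this diff only through its signature and the return len(tokens) line at the top. A tiktoken-based reading is one plausible sketch, offered purely as an assumption (the repository may wire in a different tokenizer):

import tiktoken

def estimate_tokens(text: str) -> int:
    # Assumed encoding for illustration; not confirmed by this diff
    tokens = tiktoken.get_encoding("cl100k_base").encode(text)
    return len(tokens)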
@@ -375,6 +398,10 @@ class OllamaAPI:
                 "top_k": self.top_k,
             }
 
+            # Add user_prompt to param_dict
+            if user_prompt is not None:
+                param_dict["user_prompt"] = user_prompt
+
             if (
                 hasattr(self.rag, "args")
                 and self.rag.args.history_turns is not None
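
Because user_prompt is added to param_dict only when the bracket syntax supplied one, callers that never use the bracket form see no new key and no behavioural change. A minimal self-contained sketch of that pattern (build_param_dict is a hypothetical helper mirroring the diff's construction):

from typing import Any, Optional

def build_param_dict(
    mode: str, only_need_context: bool, top_k: int, user_prompt: Optional[str]
) -> dict[str, Any]:
    param_dict: dict[str, Any] = {
        "mode": mode,
        "only_need_context": only_need_context,
        "top_k": top_k,
    }
    # Mirror the diff: include user_prompt only when one was parsed
    if user_prompt is not None:
        param_dict["user_prompt"] = user_prompt
    return param_dict

assert "user_prompt" not in build_param_dict("hybrid", False, 60, None)
assert build_param_dict("local", False, 60, "use tables")["user_prompt"] == "use tables"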
@@ -524,7 +551,7 @@ class OllamaAPI:
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming responses in Nginx proxy
                 },
             )
         else:
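
The retranslated comment belongs to the streaming path. For context, here is a minimal sketch of an NDJSON streaming response carrying these headers, assuming FastAPI (which the surrounding server code appears to use); without "X-Accel-Buffering: no", Nginx's default proxy buffering would hold back chunks until the response completes:

from fastapi.responses import StreamingResponse

async def ndjson_stream():
    # One JSON object per line (NDJSON), ending with a done marker
    yield b'{"response": "partial text", "done": false}\n'
    yield b'{"done": true}\n'

def make_stream_response() -> StreamingResponse:
    return StreamingResponse(
        ndjson_stream(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",  # keep proxies from caching chunks
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # ask Nginx to pass chunks through unbuffered
        },
    )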