Add user prompt support for Ollama api
This commit is contained in:
@@ -101,10 +101,31 @@ def estimate_tokens(text: str) -> int:
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
|
||||
def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
|
||||
"""Parse query prefix to determine search mode
|
||||
Returns tuple of (cleaned_query, search_mode, only_need_context)
|
||||
Returns tuple of (cleaned_query, search_mode, only_need_context, user_prompt)
|
||||
|
||||
Examples:
|
||||
- "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
|
||||
- "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.hybrid, False, "use mermaid format for diagrams")
|
||||
- "/local query string" -> (cleaned_query, SearchMode.local, False, None)
|
||||
"""
|
||||
# Initialize user_prompt as None
|
||||
user_prompt = None
|
||||
|
||||
# First check if there's a bracket format for user prompt
|
||||
bracket_pattern = r"^/([a-z]*)\[(.*?)\](.*)"
|
||||
bracket_match = re.match(bracket_pattern, query)
|
||||
|
||||
if bracket_match:
|
||||
mode_prefix = bracket_match.group(1)
|
||||
user_prompt = bracket_match.group(2)
|
||||
remaining_query = bracket_match.group(3).lstrip()
|
||||
|
||||
# Reconstruct query, removing the bracket part
|
||||
query = f"/{mode_prefix} {remaining_query}".strip()
|
||||
|
||||
# Unified handling of mode and only_need_context determination
|
||||
mode_map = {
|
||||
"/local ": (SearchMode.local, False),
|
||||
"/global ": (
|
||||
@@ -128,11 +149,11 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
|
||||
|
||||
for prefix, (mode, only_need_context) in mode_map.items():
|
||||
if query.startswith(prefix):
|
||||
# After removing prefix an leading spaces
|
||||
# After removing prefix and leading spaces
|
||||
cleaned_query = query[len(prefix) :].lstrip()
|
||||
return cleaned_query, mode, only_need_context
|
||||
return cleaned_query, mode, only_need_context, user_prompt
|
||||
|
||||
return query, SearchMode.hybrid, False
|
||||
return query, SearchMode.hybrid, False, user_prompt
|
||||
|
||||
|
||||
class OllamaAPI:
|
||||
@@ -362,7 +383,9 @@ class OllamaAPI:
|
||||
]
|
||||
|
||||
# Check for query prefix
|
||||
cleaned_query, mode, only_need_context = parse_query_mode(query)
|
||||
cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
|
||||
query
|
||||
)
|
||||
|
||||
start_time = time.time_ns()
|
||||
prompt_tokens = estimate_tokens(cleaned_query)
|
||||
@@ -375,6 +398,10 @@ class OllamaAPI:
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
# Add user_prompt to param_dict
|
||||
if user_prompt is not None:
|
||||
param_dict["user_prompt"] = user_prompt
|
||||
|
||||
if (
|
||||
hasattr(self.rag, "args")
|
||||
and self.rag.args.history_turns is not None
|
||||
@@ -524,7 +551,7 @@ class OllamaAPI:
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/x-ndjson",
|
||||
"X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应
|
||||
"X-Accel-Buffering": "no", # Ensure proper handling of streaming responses in Nginx proxy
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user