diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index 3aabfe35..93394fbb 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -101,10 +101,31 @@ def estimate_tokens(text: str) -> int:
     return len(tokens)
 
 
-def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
+def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
     """Parse query prefix to determine search mode
-    Returns tuple of (cleaned_query, search_mode, only_need_context)
+    Returns tuple of (cleaned_query, search_mode, only_need_context, user_prompt)
+
+    Examples:
+    - "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
+    - "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.hybrid, False, "use mermaid format for diagrams")
+    - "/local query string" -> (cleaned_query, SearchMode.local, False, None)
     """
+    # Initialize user_prompt as None
+    user_prompt = None
+
+    # First check if there's a bracket format for user prompt
+    bracket_pattern = r"^/([a-z]*)\[(.*?)\](.*)"
+    bracket_match = re.match(bracket_pattern, query)
+
+    if bracket_match:
+        mode_prefix = bracket_match.group(1)
+        user_prompt = bracket_match.group(2)
+        remaining_query = bracket_match.group(3).lstrip()
+
+        # Reconstruct query, removing the bracket part
+        query = f"/{mode_prefix} {remaining_query}".strip()
+
+    # Unified handling of mode and only_need_context determination
     mode_map = {
         "/local ": (SearchMode.local, False),
         "/global ": (
@@ -128,11 +149,11 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool]:
 
     for prefix, (mode, only_need_context) in mode_map.items():
         if query.startswith(prefix):
-            # After removing prefix an leading spaces
+            # After removing prefix and leading spaces
            cleaned_query = query[len(prefix) :].lstrip()
-            return cleaned_query, mode, only_need_context
+            return cleaned_query, mode, only_need_context, user_prompt
 
-    return query, SearchMode.hybrid, False
+    return query, SearchMode.hybrid, False, user_prompt
 
 
 class OllamaAPI:
@@ -362,7 +383,9 @@ class OllamaAPI:
             ]
 
             # Check for query prefix
-            cleaned_query, mode, only_need_context = parse_query_mode(query)
+            cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
+                query
+            )
 
             start_time = time.time_ns()
             prompt_tokens = estimate_tokens(cleaned_query)
@@ -375,6 +398,10 @@ class OllamaAPI:
                 "top_k": self.top_k,
             }
 
+            # Add user_prompt to param_dict
+            if user_prompt is not None:
+                param_dict["user_prompt"] = user_prompt
+
             if (
                 hasattr(self.rag, "args")
                 and self.rag.args.history_turns is not None
@@ -524,7 +551,7 @@ class OllamaAPI:
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming responses in Nginx proxy
                 },
             )
         else:
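
Usage sketch (not part of the patch): with the change above applied, parse_query_mode should return the four-tuple described in the new docstring examples. The import below assumes SearchMode is defined in the same module, as the hunks suggest; the query strings are illustrative.

from lightrag.api.routers.ollama_api import SearchMode, parse_query_mode

# Mode prefix plus a bracketed user prompt
cleaned, mode, only_ctx, user_prompt = parse_query_mode(
    "/local[use mermaid format for diagrams] query string"
)
assert mode == SearchMode.local
assert only_ctx is False
assert user_prompt == "use mermaid format for diagrams"
assert cleaned == "query string"

# Bracket with no mode prefix falls through to hybrid mode
_, mode, _, user_prompt = parse_query_mode(
    "/[use mermaid format for diagrams] query string"
)
assert mode == SearchMode.hybrid
assert user_prompt == "use mermaid format for diagrams"

# Plain prefix without a bracket leaves user_prompt as None
cleaned, mode, _, user_prompt = parse_query_mode("/local query string")
assert mode == SearchMode.local and user_prompt is None
assert cleaned == "query string"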