diff --git a/README.md b/README.md index 38f936e1..6e8d6507 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,38 @@ response = rag.query( ) ``` +### Custom Prompt Support +LightRAG now supports custom prompts for fine-tuned control over the system's behavior. Here's how to use it: +```python +from lightrag import LightRAG, QueryParam + +# Initialize LightRAG +rag = LightRAG(working_dir=WORKING_DIR) + +# Create query parameters +query_param = QueryParam( + mode="hybrid", # or other mode: "local", "global", "hybrid" +) + +# Example 1: Using the default system prompt +response_default = rag.query( + "What are the primary benefits of renewable energy?", + param=query_param +) +print(response_default) + +# Example 2: Using a custom prompt +custom_prompt = """ +You are an expert assistant in environmental science. Provide detailed and structured answers with examples. +""" +response_custom = rag.query( + "What are the primary benefits of renewable energy?", + param=query_param, + prompt=custom_prompt # Pass the custom prompt +) +print(response_custom) +```
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0918c1ba..7e8a3bb7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -892,11 +892,13 @@ class LightRAG: if update_storage: await self._insert_done() - def query(self, query: str, param: QueryParam = QueryParam()): + def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()): loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param)) + return loop.run_until_complete(self.aquery(query, prompt, param)) - async def aquery(self, query: str, param: QueryParam = QueryParam()): + async def aquery( + self, query: str, prompt: str = "", param: QueryParam = QueryParam() + ): if param.mode in ["local", "global", "hybrid"]: response = await kg_query( query, @@ -914,6 +916,7 @@ class LightRAG: global_config=asdict(self), embedding_func=None, ), + prompt=prompt, ) elif param.mode == "naive": response = await naive_query( diff --git a/lightrag/operate.py b/lightrag/operate.py index 0469fb7e..10c76bcc 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -574,6 +574,7 @@ async def kg_query( query_param: QueryParam, global_config: dict, hashing_kv: BaseKVStorage = None, + prompt: str = "", ) -> str: # Handle cache use_model_func = global_config["llm_model_func"] @@ -637,7 +638,7 @@ async def kg_query( query_param.conversation_history, query_param.history_turns ) - sys_prompt_temp = PROMPTS["rag_response"] + sys_prompt_temp = prompt if prompt else PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type,