fix Ollama bugs

This commit is contained in:
LarFii
2024-10-17 16:02:43 +08:00
parent ccbb69cbdf
commit 7163f70924
2 changed files with 48 additions and 35 deletions

View File

@@ -144,7 +144,7 @@ rag = LightRAG(
</details> </details>
<details> <details>
<summary> Using Ollama Models (There are some bugs. I'll fix them ASAP.) </summary> <summary> Using Ollama Models </summary>
If you want to use Ollama models, you only need to set LightRAG as follows: If you want to use Ollama models, you only need to set LightRAG as follows:
```python ```python

View File

@@ -387,6 +387,7 @@ async def local_query(
query_param: QueryParam, query_param: QueryParam,
global_config: dict, global_config: dict,
) -> str: ) -> str:
context = None
use_model_func = global_config["llm_model_func"] use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -399,7 +400,9 @@ async def local_query(
keywords = ', '.join(keywords) keywords = ', '.join(keywords)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
try: try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json.loads(result) keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", []) keywords = keywords_data.get("low_level_keywords", [])
keywords = ', '.join(keywords) keywords = ', '.join(keywords)
@@ -407,13 +410,14 @@ async def local_query(
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}") print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
context = await _build_local_query_context( if keywords:
keywords, context = await _build_local_query_context(
knowledge_graph_inst, keywords,
entities_vdb, knowledge_graph_inst,
text_chunks_db, entities_vdb,
query_param, text_chunks_db,
) query_param,
)
if query_param.only_need_context: if query_param.only_need_context:
return context return context
if context is None: if context is None:
@@ -614,6 +618,7 @@ async def global_query(
query_param: QueryParam, query_param: QueryParam,
global_config: dict, global_config: dict,
) -> str: ) -> str:
context = None
use_model_func = global_config["llm_model_func"] use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -626,7 +631,9 @@ async def global_query(
keywords = ', '.join(keywords) keywords = ', '.join(keywords)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
try: try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json.loads(result) keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", []) keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords) keywords = ', '.join(keywords)
@@ -635,15 +642,15 @@ async def global_query(
# Handle parsing error # Handle parsing error
print(f"JSON parsing error: {e}") print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
if keywords:
context = await _build_global_query_context( context = await _build_global_query_context(
keywords, keywords,
knowledge_graph_inst, knowledge_graph_inst,
entities_vdb, entities_vdb,
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
if query_param.only_need_context: if query_param.only_need_context:
return context return context
@@ -836,6 +843,8 @@ async def hybrid_query(
query_param: QueryParam, query_param: QueryParam,
global_config: dict, global_config: dict,
) -> str: ) -> str:
low_level_context = None
high_level_context = None
use_model_func = global_config["llm_model_func"] use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -850,7 +859,9 @@ async def hybrid_query(
ll_keywords = ', '.join(ll_keywords) ll_keywords = ', '.join(ll_keywords)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
try: try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json.loads(result) keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", []) hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", [])
@@ -861,22 +872,24 @@ async def hybrid_query(
print(f"JSON parsing error: {e}") print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
low_level_context = await _build_local_query_context( if ll_keywords:
ll_keywords, low_level_context = await _build_local_query_context(
knowledge_graph_inst, ll_keywords,
entities_vdb, knowledge_graph_inst,
text_chunks_db, entities_vdb,
query_param, text_chunks_db,
) query_param,
)
high_level_context = await _build_global_query_context( if hl_keywords:
hl_keywords, high_level_context = await _build_global_query_context(
knowledge_graph_inst, hl_keywords,
entities_vdb, knowledge_graph_inst,
relationships_vdb, entities_vdb,
text_chunks_db, relationships_vdb,
query_param, text_chunks_db,
) query_param,
)
context = combine_contexts(high_level_context, low_level_context) context = combine_contexts(high_level_context, low_level_context)