feat(lightrag): Implement mix search mode combining knowledge graph and vector retrieval

- Add 'mix' mode to QueryParam for hybrid search functionality
- Implement mix_kg_vector_query to combine knowledge graph and vector search results
- Update LightRAG class to handle 'mix' mode queries
- Enhance README with examples and explanations for the new mix search mode (a usage sketch follows below)
- Introduce new prompt structure for generating responses based on combined search results
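
A minimal usage sketch of the new mode (assuming a LightRAG instance configured with an LLM and embedding function as described in the README; the question string is only an illustration):

    from lightrag import LightRAG, QueryParam

    # Placeholder setup: in practice, pass llm_model_func / embedding_func etc.
    # exactly as the README shows for the other modes.
    rag = LightRAG(working_dir="./rag_storage")

    # "mix" runs knowledge-graph retrieval and vector chunk retrieval in
    # parallel and merges both contexts into one prompt before answering.
    answer = rag.query(
        "What themes connect the main characters?",
        param=QueryParam(mode="mix"),
    )
    print(answer)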
Author: Magic_yuan
Date: 2024-12-28 11:56:28 +08:00
parent e6b2f68e7c
commit aaaf617451
5 changed files with 305 additions and 2 deletions


@@ -1147,3 +1147,195 @@ async def naive_query(
    )
    return response


async def mix_kg_vector_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

    This function performs a hybrid search by:
    1. Extracting semantic information from knowledge graph
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
    # 1. Cache handling
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash("mix", query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix"
    )
    if cached_response is not None:
        return cached_response

    # 2. Execute knowledge graph and vector searches in parallel
    async def get_kg_context():
        try:
            # Reuse keyword extraction logic from kg_query
            example_number = global_config["addon_params"].get("example_number", None)
            if example_number and example_number < len(
                PROMPTS["keywords_extraction_examples"]
            ):
                examples = "\n".join(
                    PROMPTS["keywords_extraction_examples"][: int(example_number)]
                )
            else:
                examples = "\n".join(PROMPTS["keywords_extraction_examples"])
            language = global_config["addon_params"].get(
                "language", PROMPTS["DEFAULT_LANGUAGE"]
            )

            # Extract keywords using LLM
            kw_prompt = PROMPTS["keywords_extraction"].format(
                query=query, examples=examples, language=language
            )
            result = await use_model_func(kw_prompt, keyword_extraction=True)

            match = re.search(r"\{.*\}", result, re.DOTALL)
            if not match:
                logger.warning(
                    "No JSON-like structure found in keywords extraction result"
                )
                return None
            result = match.group(0)
            keywords_data = json.loads(result)
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])

            if not hl_keywords and not ll_keywords:
                logger.warning("Both high-level and low-level keywords are empty")
                return None

            # Convert keyword lists to strings
            ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
            hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

            # Set query mode based on available keywords
            if not ll_keywords_str and not hl_keywords_str:
                return None
            elif not ll_keywords_str:
                query_param.mode = "global"
            elif not hl_keywords_str:
                query_param.mode = "local"
            else:
                query_param.mode = "hybrid"

            # Build knowledge graph context
            context = await _build_query_context(
                [ll_keywords_str, hl_keywords_str],
                knowledge_graph_inst,
                entities_vdb,
                relationships_vdb,
                text_chunks_db,
                query_param,
            )
            return context
        except Exception as e:
            logger.error(f"Error in get_kg_context: {str(e)}")
            return None

    async def get_vector_context():
        # Reuse vector search logic from naive_query
        try:
            # Reduce top_k for vector search in hybrid mode since we have structured information from KG
            mix_topk = min(10, query_param.top_k)
            results = await chunks_vdb.query(query, top_k=mix_topk)
            if not results:
                return None

            chunks_ids = [r["id"] for r in results]
            chunks = await text_chunks_db.get_by_ids(chunks_ids)

            valid_chunks = [
                chunk for chunk in chunks if chunk is not None and "content" in chunk
            ]
            if not valid_chunks:
                return None

            maybe_trun_chunks = truncate_list_by_token_size(
                valid_chunks,
                key=lambda x: x["content"],
                max_token_size=query_param.max_token_for_text_unit,
            )
            if not maybe_trun_chunks:
                return None

            return "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
        except Exception as e:
            logger.error(f"Error in get_vector_context: {e}")
            return None

    # 3. Execute both retrievals in parallel
    kg_context, vector_context = await asyncio.gather(
        get_kg_context(), get_vector_context()
    )

    # 4. Merge contexts
    if kg_context is None and vector_context is None:
        return PROMPTS["fail_response"]

    if query_param.only_need_context:
        return {"kg_context": kg_context, "vector_context": vector_context}

    # 5. Construct hybrid prompt
    sys_prompt = PROMPTS["mix_rag_response"].format(
        kg_context=kg_context
        if kg_context
        else "No relevant knowledge graph information found",
        vector_context=vector_context
        if vector_context
        else "No relevant text information found",
        response_type=query_param.response_type,
    )

    if query_param.only_need_prompt:
        return sys_prompt

    # 6. Generate response
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    # 7. Save cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode="mix",
        ),
    )

    return response
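
Since the mix path returns its raw contexts when only_need_context is set (step 4 above), callers can inspect what each retriever contributed. A small sketch, reusing the rag instance from the example near the top of this commit message; the dict keys come straight from the diff, and either value may be None when that retrieval path found nothing:

    ctx = rag.query(
        "What themes connect the main characters?",
        param=QueryParam(mode="mix", only_need_context=True),
    )
    # For mode="mix" this is a dict of the two merged inputs, not a single
    # formatted context string as in the other modes.
    print(ctx["kg_context"])      # context built from KG entities/relationships/sources
    print(ctx["vector_context"])  # raw text chunks joined with "--New Chunk--" separators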