修复bug

https://github.com/HKUDS/LightRAG/issues/306
主要修改包括:
在存储文本块数据时增加了验证,确保只存储有效的数据
在处理文本块之前增加了空列表检查
在截断文本块之前过滤掉无效的数据
增加了更多的日志警告信息
查询的修改:
添加了对 chunks 的有效性检查,过滤掉无效的 chunks:
This commit is contained in:
Magic_yuan
2024-12-09 15:08:30 +08:00
parent d8edc915e7
commit 865e76a083

View File

@@ -990,23 +990,37 @@ async def _find_related_text_unit_from_relationships(
for index, unit_list in enumerate(text_units): for index, unit_list in enumerate(text_units):
for c_id in unit_list: for c_id in unit_list:
if c_id not in all_text_units_lookup: if c_id not in all_text_units_lookup:
all_text_units_lookup[c_id] = { chunk_data = await text_chunks_db.get_by_id(c_id)
"data": await text_chunks_db.get_by_id(c_id), # Only store valid data
"order": index, if chunk_data is not None and "content" in chunk_data:
} all_text_units_lookup[c_id] = {
"data": chunk_data,
"order": index,
}
if any([v is None for v in all_text_units_lookup.values()]): if not all_text_units_lookup:
logger.warning("Text chunks are missing, maybe the storage is damaged") logger.warning("No valid text chunks found")
all_text_units = [ return []
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
] all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
all_text_units = sorted(all_text_units, key=lambda x: x["order"]) all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units, # Ensure all text chunks have content
valid_text_units = [
t for t in all_text_units if t["data"] is not None and "content" in t["data"]
]
if not valid_text_units:
logger.warning("No valid text chunks after filtering")
return []
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"], key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit, max_token_size=query_param.max_token_for_text_unit,
) )
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
return all_text_units return all_text_units
@@ -1050,24 +1064,43 @@ async def naive_query(
results = await chunks_vdb.query(query, top_k=query_param.top_k) results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results] chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids) chunks = await text_chunks_db.get_by_ids(chunks_ids)
# Filter out invalid chunks
valid_chunks = [
chunk for chunk in chunks if chunk is not None and "content" in chunk
]
if not valid_chunks:
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
maybe_trun_chunks = truncate_list_by_token_size( maybe_trun_chunks = truncate_list_by_token_size(
chunks, valid_chunks,
key=lambda x: x["content"], key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit, max_token_size=query_param.max_token_for_text_unit,
) )
if not maybe_trun_chunks:
logger.warning("No chunks left after truncation")
return PROMPTS["fail_response"]
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context: if query_param.only_need_context:
return section return section
sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format( sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type content_data=section, response_type=query_param.response_type
) )
if query_param.only_need_prompt: if query_param.only_need_prompt:
return sys_prompt return sys_prompt
response = await use_model_func( response = await use_model_func(
query, query,
system_prompt=sys_prompt, system_prompt=sys_prompt,