Add character-based splitting: if the "insert" function is given the new split_by_character parameter, the content is first split on that character; any resulting chunk whose token count exceeds max_token_size is then further split by token size (TODO: handle chunks that end up too short after character splitting).

童石渊
2025-01-07 00:28:15 +08:00
parent 39a366a3dc
commit 536d6f2283
2 changed files with 171 additions and 146 deletions
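
A minimal usage sketch of the new parameter, assuming a LightRAG instance that is already configured elsewhere (the working directory, document text, and separator below are illustrative placeholders, not part of this commit):

from lightrag import LightRAG

rag = LightRAG(working_dir="./rag_storage")  # assumed setup, configure LLM/embedding as usual

# Previous behaviour: chunk purely by overlapping token windows.
rag.insert("Some long document text ...")

# New behaviour: split on the given character first; any piece whose token count
# exceeds chunk_token_size is then re-split by token windows.
rag.insert("Paragraph one.\n\nParagraph two.\n\nParagraph three.", split_by_character="\n\n")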

View File

@@ -45,6 +45,7 @@ from .storage import (
from .prompt import GRAPH_FIELD_SEP
# future KG integrations
# from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
-llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
@@ -267,7 +268,7 @@ class LightRAG:
self.llm_model_func,
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
"JsonDocStatusStorage": JsonDocStatusStorage, "JsonDocStatusStorage": JsonDocStatusStorage,
} }
def insert(self, string_or_strings): def insert(self, string_or_strings, split_by_character=None):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings)) return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
async def ainsert(self, string_or_strings): async def ainsert(self, string_or_strings, split_by_character):
"""Insert documents with checkpoint support """Insert documents with checkpoint support
Args: Args:
string_or_strings: Single document string or list of document strings string_or_strings: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character
""" """
if isinstance(string_or_strings, str): if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings] string_or_strings = [string_or_strings]
@@ -355,10 +357,10 @@ class LightRAG:
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
-batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+batch_docs = dict(list(new_docs.items())[i: i + batch_size])
for doc_id, doc in tqdm_async(
-batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
try:
# Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
}
for dp in chunking_by_token_size(
doc["content"],
+split_by_character=split_by_character,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,
@@ -594,9 +597,9 @@ class LightRAG:
"src_id": dp["src_id"], "src_id": dp["src_id"],
"tgt_id": dp["tgt_id"], "tgt_id": dp["tgt_id"],
"content": dp["keywords"] "content": dp["keywords"]
+ dp["src_id"] + dp["src_id"]
+ dp["tgt_id"] + dp["tgt_id"]
+ dp["description"], + dp["description"],
} }
for dp in all_relationships_data for dp in all_relationships_data
} }
@@ -621,7 +624,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -637,7 +640,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -656,7 +659,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
dp
for dp in self.entities_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if entities_with_chunk:
logger.error(
@@ -909,7 +912,7 @@ class LightRAG:
dp
for dp in self.relationships_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if relations_with_chunk:
logger.error(
@@ -926,7 +929,7 @@ class LightRAG:
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
"""Get detailed information of an entity
@@ -977,7 +980,7 @@ class LightRAG:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Get detailed information of a relationship
@@ -1019,7 +1022,7 @@ class LightRAG:
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information

View File

@@ -34,30 +34,52 @@ import time
def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
    return results
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
global_config: dict,
) -> str:
use_llm_func: callable = global_config["llm_model_func"]
llm_max_tokens = global_config["llm_model_max_token_size"]
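
For illustration, a rough sketch of calling the rewritten chunker directly (token sizes and the sample text are arbitrary; the import path is an assumption about where the function lives in the package):

# assuming the function from the hunk above is importable, e.g. from lightrag.operate
from lightrag.operate import chunking_by_token_size

chunks = chunking_by_token_size(
    "section one\n\nsection two\n\nsection three",
    split_by_character="\n\n",
    overlap_token_size=16,
    max_token_size=64,
)
# Each element is a dict of the form
# {"tokens": <token count>, "content": "section one", "chunk_order_index": 0};
# character-split pieces larger than max_token_size are re-split into overlapping token windows.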
@@ -86,8 +108,8 @@ async def _handle_entity_relation_summary(
async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
@@ -107,8 +129,8 @@ async def _handle_single_entity_extraction(
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
@@ -134,10 +156,10 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entity_types = []
already_source_ids = []
@@ -181,11 +203,11 @@ async def _merge_nodes_then_upsert(
async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
already_source_ids = []
@@ -248,12 +270,12 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -305,13 +327,13 @@ async def extract_entities(
already_relations = 0
async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None
@@ -413,7 +435,7 @@ async def extract_entities(
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
@@ -423,10 +445,10 @@ async def extract_entities(
results = []
for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Extracting entities from chunks",
unit="chunk",
):
results.append(await result)
@@ -440,32 +462,32 @@ async def extract_entities(
logger.info("Inserting entities into storage...") logger.info("Inserting entities into storage...")
all_entities_data = [] all_entities_data = []
for result in tqdm_async( for result in tqdm_async(
asyncio.as_completed( asyncio.as_completed(
[ [
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items() for k, v in maybe_nodes.items()
] ]
), ),
total=len(maybe_nodes), total=len(maybe_nodes),
desc="Inserting entities", desc="Inserting entities",
unit="entity", unit="entity",
): ):
all_entities_data.append(await result) all_entities_data.append(await result)
logger.info("Inserting relationships into storage...") logger.info("Inserting relationships into storage...")
all_relationships_data = [] all_relationships_data = []
for result in tqdm_async( for result in tqdm_async(
asyncio.as_completed( asyncio.as_completed(
[ [
_merge_edges_then_upsert( _merge_edges_then_upsert(
k[0], k[1], v, knowledge_graph_inst, global_config k[0], k[1], v, knowledge_graph_inst, global_config
) )
for k, v in maybe_edges.items() for k, v in maybe_edges.items()
] ]
), ),
total=len(maybe_edges), total=len(maybe_edges),
desc="Inserting relationships", desc="Inserting relationships",
unit="relationship", unit="relationship",
): ):
all_relationships_data.append(await result) all_relationships_data.append(await result)
@@ -496,9 +518,9 @@ async def extract_entities(
"src_id": dp["src_id"], "src_id": dp["src_id"],
"tgt_id": dp["tgt_id"], "tgt_id": dp["tgt_id"],
"content": dp["keywords"] "content": dp["keywords"]
+ dp["src_id"] + dp["src_id"]
+ dp["tgt_id"] + dp["tgt_id"]
+ dp["description"], + dp["description"],
"metadata": { "metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time()) "created_at": dp.get("metadata", {}).get("created_at", time.time())
}, },
@@ -511,14 +533,14 @@ async def extract_entities(
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -638,12 +660,12 @@ async def kg_query(
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
# hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
@@ -696,9 +718,9 @@ async def _build_query_context(
query_param,
)
if (
hl_entities_context == ""
and hl_relations_context == ""
and hl_text_units_context == ""
):
logger.warn("No high level context found. Switching to local mode.")
query_param.mode = "local"
@@ -737,11 +759,11 @@ async def _build_query_context(
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# get similar entities
results = await entities_vdb.query(query, top_k=query_param.top_k)
@@ -828,10 +850,10 @@ async def _get_node_data(
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -871,8 +893,8 @@ async def _find_most_related_text_unit_from_entities(
if this_edges:
for e in this_edges:
if (
e[1] in all_one_hop_text_units_lookup
and c_id in all_one_hop_text_units_lookup[e[1]]
):
all_text_units_lookup[c_id]["relation_counts"] += 1
@@ -902,9 +924,9 @@ async def _find_most_related_text_unit_from_entities(
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
all_related_edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
@@ -942,11 +964,11 @@ async def _find_most_related_edges_from_entities(
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@@ -1044,9 +1066,9 @@ async def _get_edge_data(
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
entity_names = []
seen = set()
@@ -1081,10 +1103,10 @@ async def _find_most_related_entities_from_relationships(
async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1150,12 +1172,12 @@ def combine_contexts(entities, relationships, sources):
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
):
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -1213,7 +1235,7 @@ async def naive_query(
if len(response) > len(sys_prompt):
response = (
-response[len(sys_prompt) :]
+response[len(sys_prompt):]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
@@ -1241,15 +1263,15 @@ async def naive_query(
async def mix_kg_vector_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1274,7 +1296,7 @@ async def mix_kg_vector_query(
# Reuse keyword extraction logic from kg_query
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(
PROMPTS["keywords_extraction_examples"]
):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]