Add character-based splitting: when the `insert` function is given the new `split_by_character` argument, the text is first split on that character; if any resulting chunk exceeds max_token_size tokens, it is further split by token size (TODO: consider handling chunks that end up too short after character splitting).
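A minimal usage sketch of the new parameter, assuming a default-configured LightRAG instance (the working directory and sample text below are placeholders, and a real setup still needs working LLM/embedding credentials):

from lightrag import LightRAG

# Placeholder setup: relies on the default llm_model_func (gpt_4o_mini_complete);
# storage and embedding configuration are omitted for brevity.
rag = LightRAG(working_dir="./rag_storage")

text = "First section...\nSecond section...\nA very long third section..."

# Default behaviour: chunk purely by token windows (chunk_token_size with overlap).
rag.insert(text)

# New behaviour: split on the given character first; any piece whose token count
# still exceeds chunk_token_size is re-split by token windows.
rag.insert(text, split_by_character="\n")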
@@ -45,6 +45,7 @@ from .storage import (

from .prompt import GRAPH_FIELD_SEP

# future KG integrations
# from .kg.ArangoDB_impl import (

@@ -167,7 +168,7 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
-llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)

@@ -267,7 +268,7 @@ class LightRAG:
self.llm_model_func,
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
"JsonDocStatusStorage": JsonDocStatusStorage,
}

-def insert(self, string_or_strings):
+def insert(self, string_or_strings, split_by_character=None):
loop = always_get_an_event_loop()
-return loop.run_until_complete(self.ainsert(string_or_strings))
+return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))

-async def ainsert(self, string_or_strings):
+async def ainsert(self, string_or_strings, split_by_character):
"""Insert documents with checkpoint support

Args:
string_or_strings: Single document string or list of document strings
+split_by_character: if split_by_character is not None, split the string by character
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]

@@ -355,10 +357,10 @@ class LightRAG:
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
-batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+batch_docs = dict(list(new_docs.items())[i: i + batch_size])

for doc_id, doc in tqdm_async(
-batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
try:
# Update status to processing

@@ -379,6 +381,7 @@ class LightRAG:
}
for dp in chunking_by_token_size(
doc["content"],
+split_by_character=split_by_character,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
# Check if nodes exist in the knowledge graph
for need_insert_id in [src_id, tgt_id]:
if not (
await self.chunk_entity_relation_graph.has_node(need_insert_id)
):
await self.chunk_entity_relation_graph.upsert_node(
need_insert_id,

@@ -594,9 +597,9 @@ class LightRAG:
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}

@@ -621,7 +624,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),

@@ -637,7 +640,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),

@@ -656,7 +659,7 @@ class LightRAG:
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
dp
for dp in self.entities_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if entities_with_chunk:
logger.error(

@@ -909,7 +912,7 @@ class LightRAG:
dp
for dp in self.relationships_vdb.client_storage["data"]
if chunk_id
in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
]
if relations_with_chunk:
logger.error(

@@ -926,7 +929,7 @@ class LightRAG:
return asyncio.run(self.adelete_by_doc_id(doc_id))

async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
"""Get detailed information of an entity

@@ -977,7 +980,7 @@ class LightRAG:
tracemalloc.stop()

async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Get detailed information of a relationship

@@ -1019,7 +1022,7 @@ class LightRAG:
return result

def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
@@ -34,30 +34,52 @@ import time


def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
    return results


async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
) -> str:
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
@@ -86,8 +108,8 @@ async def _handle_entity_relation_summary(


async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None

@@ -107,8 +129,8 @@ async def _handle_single_entity_extraction(


async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None

@@ -134,10 +156,10 @@ async def _handle_single_relationship_extraction(


async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entity_types = []
already_source_ids = []

@@ -181,11 +203,11 @@ async def _merge_nodes_then_upsert(


async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
already_source_ids = []

@@ -248,12 +270,12 @@ async def _merge_edges_then_upsert(


async def extract_entities(
chunks: dict[str, TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -305,13 +327,13 @@ async def extract_entities(
already_relations = 0

async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None

@@ -413,7 +435,7 @@ async def extract_entities(
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",

@@ -423,10 +445,10 @@ async def extract_entities(

results = []
for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Extracting entities from chunks",
unit="chunk",
):
results.append(await result)
@@ -440,32 +462,32 @@ async def extract_entities(
logger.info("Inserting entities into storage...")
all_entities_data = []
for result in tqdm_async(
asyncio.as_completed(
[
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
),
total=len(maybe_nodes),
desc="Inserting entities",
unit="entity",
):
all_entities_data.append(await result)

logger.info("Inserting relationships into storage...")
all_relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
[
_merge_edges_then_upsert(
k[0], k[1], v, knowledge_graph_inst, global_config
)
for k, v in maybe_edges.items()
]
),
total=len(maybe_edges),
desc="Inserting relationships",
unit="relationship",
):
all_relationships_data.append(await result)

@@ -496,9 +518,9 @@ async def extract_entities(
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
"metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time())
},
@@ -511,14 +533,14 @@ async def extract_entities(


async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]

@@ -638,12 +660,12 @@ async def kg_query(


async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
# hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
@@ -696,9 +718,9 @@ async def _build_query_context(
query_param,
)
if (
hl_entities_context == ""
and hl_relations_context == ""
and hl_text_units_context == ""
):
logger.warn("No high level context found. Switching to local mode.")
query_param.mode = "local"

@@ -737,11 +759,11 @@ async def _build_query_context(


async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# get similar entities
results = await entities_vdb.query(query, top_k=query_param.top_k)

@@ -828,10 +850,10 @@ async def _get_node_data(


async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -871,8 +893,8 @@ async def _find_most_related_text_unit_from_entities(
if this_edges:
for e in this_edges:
if (
e[1] in all_one_hop_text_units_lookup
and c_id in all_one_hop_text_units_lookup[e[1]]
):
all_text_units_lookup[c_id]["relation_counts"] += 1

@@ -902,9 +924,9 @@ async def _find_most_related_text_unit_from_entities(


async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
all_related_edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]

@@ -942,11 +964,11 @@ async def _find_most_related_edges_from_entities(


async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)

@@ -1044,9 +1066,9 @@ async def _get_edge_data(


async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
entity_names = []
seen = set()
@@ -1081,10 +1103,10 @@ async def _find_most_related_entities_from_relationships(


async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])

@@ -1150,12 +1172,12 @@ def combine_contexts(entities, relationships, sources):


async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
):
# Handle cache
use_model_func = global_config["llm_model_func"]

@@ -1213,7 +1235,7 @@ async def naive_query(

if len(response) > len(sys_prompt):
response = (
-response[len(sys_prompt) :]
+response[len(sys_prompt):]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
@@ -1241,15 +1263,15 @@ async def naive_query(


async def mix_kg_vector_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.

@@ -1274,7 +1296,7 @@ async def mix_kg_vector_query(
# Reuse keyword extraction logic from kg_query
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(
PROMPTS["keywords_extraction_examples"]
):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
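As a quick sanity check of the new chunking path, a call along the following lines should leave short pieces intact and re-split only the oversized one (a sketch; it assumes chunking_by_token_size is importable from lightrag.operate and that tiktoken is available):

from lightrag.operate import chunking_by_token_size

text = "alpha\nbeta\n" + "gamma " * 3000  # the third piece exceeds max_token_size

chunks = chunking_by_token_size(
    text,
    split_by_character="\n",
    overlap_token_size=128,
    max_token_size=1024,
    tiktoken_model="gpt-4o",
)

# Expected: "alpha" and "beta" come through as single small chunks, while the long
# piece is broken into overlapping windows of at most 1024 tokens.
print([c["tokens"] for c in chunks])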