Add character-based splitting: if the insert function is given a split_by_character argument, the text is first split on that character; any resulting chunk whose token count exceeds max_token_size is then further split by token size (TODO: decide how to handle chunks that end up too short after character splitting).
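A minimal sketch of the rule described above, outside the LightRAG codebase (assumes tiktoken is installed and that its encoding_for_model can resolve the "gpt-4o" model name; the function and variable names here are illustrative, not part of this commit):

import tiktoken

def sketch_chunking(content, split_by_character="\n\n", max_token_size=1024, overlap_token_size=128):
    # Split on the separator first, then re-split only the pieces that are too long in tokens.
    enc = tiktoken.encoding_for_model("gpt-4o")
    chunks = []
    for piece in content.split(split_by_character):
        tokens = enc.encode(piece)
        if len(tokens) > max_token_size:
            # Oversized piece: fall back to sliding token windows with overlap.
            for start in range(0, len(tokens), max_token_size - overlap_token_size):
                chunks.append(enc.decode(tokens[start : start + max_token_size]).strip())
        else:
            chunks.append(piece.strip())
    return chunks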
@@ -45,6 +45,7 @@ from .storage import (
 
 from .prompt import GRAPH_FIELD_SEP
 
 # future KG integrations
 
 # from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
 
     # LLM
     llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -267,7 +268,7 @@ class LightRAG:
                 self.llm_model_func,
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
                 and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
             "JsonDocStatusStorage": JsonDocStatusStorage,
         }
 
-    def insert(self, string_or_strings):
+    def insert(self, string_or_strings, split_by_character=None):
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.ainsert(string_or_strings))
+        return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
 
-    async def ainsert(self, string_or_strings):
+    async def ainsert(self, string_or_strings, split_by_character):
         """Insert documents with checkpoint support
 
         Args:
             string_or_strings: Single document string or list of document strings
+            split_by_character: if split_by_character is not None, split the string by character
         """
         if isinstance(string_or_strings, str):
             string_or_strings = [string_or_strings]
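The hunk above only threads the new parameter through the public API: insert forwards split_by_character to ainsert. For illustration, a call against the updated signature might look like this (a sketch; the constructor arguments and the file path are placeholders, not part of this commit):

from lightrag import LightRAG

rag = LightRAG(working_dir="./rag_storage")  # LLM/embedding setup omitted
text = open("book.txt", encoding="utf-8").read()

# Previous behavior: pure token-window chunking.
rag.insert(text)

# New behavior: split on blank lines first; any piece longer than
# chunk_token_size tokens is then re-split by token windows.
rag.insert(text, split_by_character="\n\n")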
@@ -355,10 +357,10 @@ class LightRAG:
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            batch_docs = dict(list(new_docs.items())[i: i + batch_size])
 
             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 try:
                     # Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
                         }
                         for dp in chunking_by_token_size(
                             doc["content"],
+                            split_by_character=split_by_character,
                             overlap_token_size=self.chunk_overlap_token_size,
                             max_token_size=self.chunk_token_size,
                             tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
             # Check if nodes exist in the knowledge graph
             for need_insert_id in [src_id, tgt_id]:
                 if not (
                     await self.chunk_entity_relation_graph.has_node(need_insert_id)
                 ):
                     await self.chunk_entity_relation_graph.upsert_node(
                         need_insert_id,
@@ -594,9 +597,9 @@ class LightRAG:
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
                 "content": dp["keywords"]
                 + dp["src_id"]
                 + dp["tgt_id"]
                 + dp["description"],
             }
             for dp in all_relationships_data
         }
@@ -621,7 +624,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
                 and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -637,7 +640,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
                 and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -656,7 +659,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
                 and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
                 dp
                 for dp in self.entities_vdb.client_storage["data"]
                 if chunk_id
                 in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
             ]
             if entities_with_chunk:
                 logger.error(
@@ -909,7 +912,7 @@ class LightRAG:
                 dp
                 for dp in self.relationships_vdb.client_storage["data"]
                 if chunk_id
                 in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
             ]
             if relations_with_chunk:
                 logger.error(
@@ -926,7 +929,7 @@ class LightRAG:
         return asyncio.run(self.adelete_by_doc_id(doc_id))
 
     async def get_entity_info(
         self, entity_name: str, include_vector_data: bool = False
     ):
         """Get detailed information of an entity
 
@@ -977,7 +980,7 @@ class LightRAG:
             tracemalloc.stop()
 
     async def get_relation_info(
         self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Get detailed information of a relationship
 
@@ -1019,7 +1022,7 @@ class LightRAG:
         return result
 
     def get_relation_info_sync(
         self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Synchronous version of getting relationship information
 
@@ -34,30 +34,52 @@ import time
 
 
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
     return results
 
 
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
     global_config: dict,
 ) -> str:
     use_llm_func: callable = global_config["llm_model_func"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
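To see the two code paths of the rewritten chunker side by side, a usage sketch (assumes the function is importable from lightrag.operate and that tiktoken can resolve the default "gpt-4o" model name; the sample text is made up):

from lightrag.operate import chunking_by_token_size

doc = "Short paragraph one.\n\nShort paragraph two.\n\n" + "word " * 5000

# Default path: sliding windows of up to 1024 tokens with 128-token overlap.
token_chunks = chunking_by_token_size(doc)

# Character path: split on blank lines; the two short paragraphs are kept
# whole, only the oversized third piece is re-split by token windows.
char_chunks = chunking_by_token_size(doc, split_by_character="\n\n")

for c in char_chunks[:3]:
    print(c["chunk_order_index"], c["tokens"], c["content"][:40])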
@@ -86,8 +108,8 @@ async def _handle_entity_relation_summary(
 
 
 async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
@@ -107,8 +129,8 @@ async def _handle_single_entity_extraction(
 
 
 async def _handle_single_relationship_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
     if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
         return None
@@ -134,10 +156,10 @@ async def _handle_single_relationship_extraction(
 
 
 async def _merge_nodes_then_upsert(
     entity_name: str,
     nodes_data: list[dict],
     knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
 ):
     already_entity_types = []
     already_source_ids = []
@@ -181,11 +203,11 @@ async def _merge_nodes_then_upsert(
 
 
 async def _merge_edges_then_upsert(
     src_id: str,
     tgt_id: str,
     edges_data: list[dict],
     knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
 ):
     already_weights = []
     already_source_ids = []
@@ -248,12 +270,12 @@ async def _merge_edges_then_upsert(
 
 
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict,
     llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -305,13 +327,13 @@ async def extract_entities(
     already_relations = 0
 
     async def _user_llm_func_with_cache(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
             need_to_restore = False
             if (
                 global_config["embedding_cache_config"]
                 and global_config["embedding_cache_config"]["enabled"]
             ):
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
@@ -413,7 +435,7 @@ async def extract_entities(
         already_relations += len(maybe_edges)
         now_ticks = PROMPTS["process_tickers"][
             already_processed % len(PROMPTS["process_tickers"])
         ]
         print(
             f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
             end="",
@@ -423,10 +445,10 @@ async def extract_entities(
 
     results = []
     for result in tqdm_async(
         asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
         total=len(ordered_chunks),
         desc="Extracting entities from chunks",
         unit="chunk",
     ):
         results.append(await result)
 
@@ -440,32 +462,32 @@ async def extract_entities(
     logger.info("Inserting entities into storage...")
     all_entities_data = []
     for result in tqdm_async(
         asyncio.as_completed(
             [
                 _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
                 for k, v in maybe_nodes.items()
             ]
         ),
         total=len(maybe_nodes),
         desc="Inserting entities",
         unit="entity",
     ):
         all_entities_data.append(await result)
 
     logger.info("Inserting relationships into storage...")
     all_relationships_data = []
     for result in tqdm_async(
         asyncio.as_completed(
             [
                 _merge_edges_then_upsert(
                     k[0], k[1], v, knowledge_graph_inst, global_config
                 )
                 for k, v in maybe_edges.items()
             ]
         ),
         total=len(maybe_edges),
         desc="Inserting relationships",
         unit="relationship",
     ):
         all_relationships_data.append(await result)
 
@@ -496,9 +518,9 @@ async def extract_entities(
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
                 "content": dp["keywords"]
                 + dp["src_id"]
                 + dp["tgt_id"]
                 + dp["description"],
                 "metadata": {
                     "created_at": dp.get("metadata", {}).get("created_at", time.time())
                 },
@@ -511,14 +533,14 @@ async def extract_entities(
 
 
 async def kg_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -638,12 +660,12 @@ async def kg_query(
 
 
 async def _build_query_context(
     query: list,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
     # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
     # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
@@ -696,9 +718,9 @@ async def _build_query_context(
         query_param,
     )
     if (
         hl_entities_context == ""
         and hl_relations_context == ""
         and hl_text_units_context == ""
     ):
         logger.warn("No high level context found. Switching to local mode.")
         query_param.mode = "local"
@@ -737,11 +759,11 @@ async def _build_query_context(
 
 
 async def _get_node_data(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
     # get similar entities
     results = await entities_vdb.query(query, top_k=query_param.top_k)
@@ -828,10 +850,10 @@ async def _get_node_data(
 
 
 async def _find_most_related_text_unit_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -871,8 +893,8 @@ async def _find_most_related_text_unit_from_entities(
         if this_edges:
             for e in this_edges:
                 if (
                     e[1] in all_one_hop_text_units_lookup
                     and c_id in all_one_hop_text_units_lookup[e[1]]
                 ):
                     all_text_units_lookup[c_id]["relation_counts"] += 1
 
@@ -902,9 +924,9 @@ async def _find_most_related_text_unit_from_entities(
 
 
 async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
 ):
     all_related_edges = await asyncio.gather(
         *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
@@ -942,11 +964,11 @@ async def _find_most_related_edges_from_entities(
 
 
 async def _get_edge_data(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
 
@@ -1044,9 +1066,9 @@ async def _get_edge_data(
 
 
 async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
 ):
     entity_names = []
     seen = set()
@@ -1081,10 +1103,10 @@ async def _find_most_related_entities_from_relationships(
 
 
 async def _find_related_text_unit_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1150,12 +1172,12 @@ def combine_contexts(entities, relationships, sources):
 
 
 async def naive_query(
     query,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
 ):
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1213,7 +1235,7 @@ async def naive_query(
 
     if len(response) > len(sys_prompt):
         response = (
-            response[len(sys_prompt) :]
+            response[len(sys_prompt):]
             .replace(sys_prompt, "")
             .replace("user", "")
             .replace("model", "")
@@ -1241,15 +1263,15 @@ async def naive_query(
 
 
 async def mix_kg_vector_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
     hashing_kv: BaseKVStorage = None,
 ) -> str:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1274,7 +1296,7 @@ async def mix_kg_vector_query(
     # Reuse keyword extraction logic from kg_query
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(
         PROMPTS["keywords_extraction_examples"]
     ):
         examples = "\n".join(
             PROMPTS["keywords_extraction_examples"][: int(example_number)]