Merge branch 'main' into add-env-settings

yangdx
2025-02-16 22:34:39 +08:00
25 changed files with 1086 additions and 793 deletions


@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
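
The switch from `Union[str, None]` to `str | None` is safe on pre-3.10 interpreters only because of the `from __future__ import annotations` import added above: under PEP 563, every annotation is stored as a string and never evaluated at definition time. A minimal sketch of the effect (the `chunk` function is illustrative, not part of the codebase):

```python
from __future__ import annotations  # PEP 563: annotations become lazy strings

# Without the future import, `str | None` in an annotation raises
# TypeError on Python 3.8/3.9, because `type | type` is a 3.10 feature.
# With it, the expression is kept as text and never evaluated.
def chunk(split_by_character: str | None = None) -> list[str]:
    return [] if split_by_character is None else split_by_character.split()

print(chunk.__annotations__["split_by_character"])  # prints: str | None
```

Note that the future import only shields annotations; `str | None` used as a runtime value (for example in `isinstance`) still requires Python 3.10+.
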
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
-        already_weights.append(already_edge["weight"])
-        already_source_ids.extend(
-            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
-        )
-        already_description.append(already_edge["description"])
-        already_keywords.extend(
-            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
-        )
+        # Handle the case where get_edge returns None or missing fields
+        if already_edge:
+            # Get weight with default 0.0 if missing
+            if "weight" in already_edge:
+                already_weights.append(already_edge["weight"])
+            else:
+                logger.warning(
+                    f"Edge between {src_id} and {tgt_id} missing weight field"
+                )
+                already_weights.append(0.0)
+            # Get source_id with empty string default if missing or None
+            if "source_id" in already_edge and already_edge["source_id"] is not None:
+                already_source_ids.extend(
+                    split_string_by_multi_markers(
+                        already_edge["source_id"], [GRAPH_FIELD_SEP]
+                    )
+                )
+            # Get description with empty string default if missing or None
+            if (
+                "description" in already_edge
+                and already_edge["description"] is not None
+            ):
+                already_description.append(already_edge["description"])
+            # Get keywords with empty string default if missing or None
+            if "keywords" in already_edge and already_edge["keywords"] is not None:
+                already_keywords.extend(
+                    split_string_by_multi_markers(
+                        already_edge["keywords"], [GRAPH_FIELD_SEP]
+                    )
+                )
+    # Process edges_data with None checks
     weight = sum([dp["weight"] for dp in edges_data] + already_weights)
     description = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["description"] for dp in edges_data] + already_description))
+        sorted(
+            set(
+                [dp["description"] for dp in edges_data if dp.get("description")]
+                + already_description
+            )
+        )
     )
     keywords = GRAPH_FIELD_SEP.join(
-        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
+        sorted(
+            set(
+                [dp["keywords"] for dp in edges_data if dp.get("keywords")]
+                + already_keywords
+            )
+        )
     )
     source_id = GRAPH_FIELD_SEP.join(
-        set([dp["source_id"] for dp in edges_data] + already_source_ids)
+        set(
+            [dp["source_id"] for dp in edges_data if dp.get("source_id")]
+            + already_source_ids
+        )
     )
     for need_insert_id in [src_id, tgt_id]:
         if not (await knowledge_graph_inst.has_node(need_insert_id)):
             await knowledge_graph_inst.upsert_node(
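
The guards above exist because some storage backends can return `None`, or a record with missing fields, from `get_edge`; the old unconditional `already_edge["weight"]` subscript then raised `TypeError` or `KeyError` and aborted the merge. A self-contained sketch of the same defensive pattern (names and the separator value are illustrative stand-ins, not the library API):

```python
from __future__ import annotations

GRAPH_FIELD_SEP = "<SEP>"  # stand-in for the real separator constant

def merge_edge_fields(already_edge: dict | None) -> tuple[list[float], list[str]]:
    """Collect weight/source_id from a possibly absent or partial edge record."""
    weights: list[float] = []
    source_ids: list[str] = []
    if already_edge:  # skip entirely when the backend returned None or {}
        # default the weight to 0.0 instead of crashing on a partial record
        weights.append(already_edge.get("weight", 0.0))
        if already_edge.get("source_id") is not None:
            source_ids.extend(already_edge["source_id"].split(GRAPH_FIELD_SEP))
    return weights, source_ids

print(merge_edge_fields(None))                      # ([], [])
print(merge_edge_fields({"weight": 2.0}))           # ([2.0], [])
print(merge_edge_fields({"source_id": "a<SEP>b"}))  # ([0.0], ['a', 'b'])
```
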
@@ -295,9 +337,9 @@ async def extract_entities(
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
-) -> Union[BaseGraphStorage, None]:
+    global_config: dict[str, str],
+    llm_response_cache: BaseKVStorage | None = None,
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
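
Across these signatures, the bare `BaseKVStorage = None` default was misleading: the annotation promised a storage instance while the default was `None`, which most type checkers now flag. `BaseKVStorage | None = None` states the optionality explicitly, and the body must then guard before use. A small sketch of that guard pattern with a stand-in class (not the real `BaseKVStorage` API):

```python
from __future__ import annotations

import asyncio

class KVStorage:
    """Stand-in cache with one illustrative async accessor."""
    async def get_by_id(self, key: str) -> str | None:
        return None  # always a miss in this sketch

async def cached_call(prompt: str, cache: KVStorage | None = None) -> str:
    # Optional dependency: consult the cache only when one was injected.
    if cache is not None:
        hit = await cache.get_by_id(prompt)
        if hit is not None:
            return hit
    return f"llm-response({prompt})"  # fall through to the model call

print(asyncio.run(cached_call("q")))               # no cache injected
print(asyncio.run(cached_call("q", KVStorage())))  # cache miss, same result
```
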
@@ -563,15 +605,15 @@ async def extract_entities
 async def kg_query(
-    query,
+    query: str,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-    prompt: str = "",
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -684,8 +726,8 @@ async def kg_query
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> tuple[list[str], list[str]]:
     """
     Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
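
Widening the return annotation to `str | AsyncIterator[str]` signals that, when streaming is requested (presumably via a flag on `QueryParam`), the response comes back as an async iterator of chunks rather than a final string, so callers must branch on the runtime type. A caller-side sketch with a hypothetical `answer` stub in place of the real query functions:

```python
from __future__ import annotations

import asyncio
from typing import AsyncIterator

async def answer(stream: bool) -> str | AsyncIterator[str]:
    """Hypothetical stand-in for naive_query / mix_kg_vector_query."""
    async def gen() -> AsyncIterator[str]:
        for part in ("knowledge", " graph", " answer"):
            yield part
    return gen() if stream else "knowledge graph answer"

async def main() -> None:
    result = await answer(stream=True)
    if isinstance(result, str):  # non-streaming path: full string
        print(result)
    else:  # streaming path: consume chunks as they arrive
        async for chunk in result:
            print(chunk, end="")
        print()

asyncio.run(main())
```
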
@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):
 async def naive_query(
-    query,
+    query: str,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-):
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
     """
     Refactored kg_query that does NOT extract keywords by itself.
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
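
Given that contract, the intended flow is a two-step pipeline: call `extract_keywords_only` once, store the results on the `QueryParam`, then query. A runnable sketch with trimmed stand-ins (the real functions also take the storage arguments shown in the hunks above; `hl_keywords`/`ll_keywords` as `QueryParam` fields follow the docstring and are otherwise an assumption):

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field

@dataclass
class QueryParam:  # minimal stand-in for the real class
    mode: str = "hybrid"
    hl_keywords: list[str] = field(default_factory=list)
    ll_keywords: list[str] = field(default_factory=list)

async def extract_keywords_only(text: str, param: QueryParam) -> tuple[list[str], list[str]]:
    return ["themes"], ["entities"]  # stubbed LLM keyword extraction

async def kg_query_with_keywords(query: str, param: QueryParam) -> str:
    # Consumes param.hl_keywords / param.ll_keywords as-is, per the docstring.
    return f"hl={param.hl_keywords} ll={param.ll_keywords}"

async def main() -> None:
    param = QueryParam()
    param.hl_keywords, param.ll_keywords = await extract_keywords_only("query", param)
    print(await kg_query_with_keywords("query", param))

asyncio.run(main())
```
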