Merge branch 'main' into add-env-settings
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from typing import Any, Union
|
||||
from typing import Any, AsyncIterator
|
||||
from collections import Counter, defaultdict
|
||||
from .utils import (
|
||||
logger,
|
||||
@@ -36,7 +38,7 @@ import time
|
||||
|
||||
def chunking_by_token_size(
|
||||
content: str,
|
||||
split_by_character: Union[str, None] = None,
|
||||
split_by_character: str | None = None,
|
||||
split_by_character_only: bool = False,
|
||||
overlap_token_size: int = 128,
|
||||
max_token_size: int = 1024,
|
||||
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
|
||||
|
||||
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
|
||||
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
||||
already_weights.append(already_edge["weight"])
|
||||
already_source_ids.extend(
|
||||
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
|
||||
)
|
||||
already_description.append(already_edge["description"])
|
||||
already_keywords.extend(
|
||||
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
|
||||
)
|
||||
# Handle the case where get_edge returns None or missing fields
|
||||
if already_edge:
|
||||
# Get weight with default 0.0 if missing
|
||||
if "weight" in already_edge:
|
||||
already_weights.append(already_edge["weight"])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Edge between {src_id} and {tgt_id} missing weight field"
|
||||
)
|
||||
already_weights.append(0.0)
|
||||
|
||||
# Get source_id with empty string default if missing or None
|
||||
if "source_id" in already_edge and already_edge["source_id"] is not None:
|
||||
already_source_ids.extend(
|
||||
split_string_by_multi_markers(
|
||||
already_edge["source_id"], [GRAPH_FIELD_SEP]
|
||||
)
|
||||
)
|
||||
|
||||
# Get description with empty string default if missing or None
|
||||
if (
|
||||
"description" in already_edge
|
||||
and already_edge["description"] is not None
|
||||
):
|
||||
already_description.append(already_edge["description"])
|
||||
|
||||
# Get keywords with empty string default if missing or None
|
||||
if "keywords" in already_edge and already_edge["keywords"] is not None:
|
||||
already_keywords.extend(
|
||||
split_string_by_multi_markers(
|
||||
already_edge["keywords"], [GRAPH_FIELD_SEP]
|
||||
)
|
||||
)
|
||||
|
||||
# Process edges_data with None checks
|
||||
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
||||
description = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
||||
sorted(
|
||||
set(
|
||||
[dp["description"] for dp in edges_data if dp.get("description")]
|
||||
+ already_description
|
||||
)
|
||||
)
|
||||
)
|
||||
keywords = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
|
||||
sorted(
|
||||
set(
|
||||
[dp["keywords"] for dp in edges_data if dp.get("keywords")]
|
||||
+ already_keywords
|
||||
)
|
||||
)
|
||||
)
|
||||
source_id = GRAPH_FIELD_SEP.join(
|
||||
set([dp["source_id"] for dp in edges_data] + already_source_ids)
|
||||
set(
|
||||
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
|
||||
+ already_source_ids
|
||||
)
|
||||
)
|
||||
|
||||
for need_insert_id in [src_id, tgt_id]:
|
||||
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
||||
await knowledge_graph_inst.upsert_node(
|
||||
@@ -295,9 +337,9 @@ async def extract_entities(
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entity_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
global_config: dict,
|
||||
llm_response_cache: BaseKVStorage = None,
|
||||
) -> Union[BaseGraphStorage, None]:
|
||||
global_config: dict[str, str],
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
) -> BaseGraphStorage | None:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
enable_llm_cache_for_entity_extract: bool = global_config[
|
||||
@@ -563,15 +605,15 @@ async def extract_entities(
|
||||
|
||||
|
||||
async def kg_query(
|
||||
query,
|
||||
query: str,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
prompt: str = "",
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> str:
|
||||
# Handle cache
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
@@ -684,8 +726,8 @@ async def kg_query(
|
||||
async def extract_keywords_only(
|
||||
text: str,
|
||||
param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
||||
@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
) -> str:
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str | AsyncIterator[str]:
|
||||
"""
|
||||
Hybrid retrieval implementation combining knowledge graph and vector search.
|
||||
|
||||
@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):
|
||||
|
||||
|
||||
async def naive_query(
|
||||
query,
|
||||
query: str,
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
):
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str | AsyncIterator[str]:
|
||||
# Handle cache
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
||||
@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
) -> str:
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str | AsyncIterator[str]:
|
||||
"""
|
||||
Refactored kg_query that does NOT extract keywords by itself.
|
||||
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
|
||||
|
Reference in New Issue
Block a user