update insert custom kg
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
||||
|
||||
__version__ = "1.0.2"
|
||||
__version__ = "1.0.3"
|
||||
__author__ = "Zirui Guo"
|
||||
__url__ = "https://github.com/HKUDS/LightRAG"
|
||||
|
@@ -329,13 +329,39 @@ class LightRAG:
|
||||
async def ainsert_custom_kg(self, custom_kg: dict):
|
||||
update_storage = False
|
||||
try:
|
||||
# Insert chunks into vector storage
|
||||
all_chunks_data = {}
|
||||
chunk_to_source_map = {}
|
||||
for chunk_data in custom_kg.get("chunks", []):
|
||||
chunk_content = chunk_data["content"]
|
||||
source_id = chunk_data["source_id"]
|
||||
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
|
||||
|
||||
chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
|
||||
all_chunks_data[chunk_id] = chunk_entry
|
||||
chunk_to_source_map[source_id] = chunk_id
|
||||
update_storage = True
|
||||
|
||||
if self.chunks_vdb is not None and all_chunks_data:
|
||||
await self.chunks_vdb.upsert(all_chunks_data)
|
||||
if self.text_chunks is not None and all_chunks_data:
|
||||
await self.text_chunks.upsert(all_chunks_data)
|
||||
|
||||
# Insert entities into knowledge graph
|
||||
all_entities_data = []
|
||||
for entity_data in custom_kg.get("entities", []):
|
||||
entity_name = f'"{entity_data["entity_name"].upper()}"'
|
||||
entity_type = entity_data.get("entity_type", "UNKNOWN")
|
||||
description = entity_data.get("description", "No description provided")
|
||||
source_id = entity_data["source_id"]
|
||||
# source_id = entity_data["source_id"]
|
||||
source_chunk_id = entity_data.get("source_id", "UNKNOWN")
|
||||
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
|
||||
|
||||
# Log if source_id is UNKNOWN
|
||||
if source_id == "UNKNOWN":
|
||||
logger.warning(
|
||||
f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
|
||||
)
|
||||
|
||||
# Prepare node data
|
||||
node_data = {
|
||||
@@ -359,7 +385,15 @@ class LightRAG:
|
||||
description = relationship_data["description"]
|
||||
keywords = relationship_data["keywords"]
|
||||
weight = relationship_data.get("weight", 1.0)
|
||||
source_id = relationship_data["source_id"]
|
||||
# source_id = relationship_data["source_id"]
|
||||
source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
|
||||
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
|
||||
|
||||
# Log if source_id is UNKNOWN
|
||||
if source_id == "UNKNOWN":
|
||||
logger.warning(
|
||||
f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
|
||||
)
|
||||
|
||||
# Check if nodes exist in the knowledge graph
|
||||
for need_insert_id in [src_id, tgt_id]:
|
||||
|
@@ -502,11 +502,12 @@ async def gpt_4o_mini_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def nvidia_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
result = await openai_complete_if_cache(
|
||||
"nvidia/llama-3.1-nemotron-70b-instruct", #context length 128k
|
||||
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
@@ -517,6 +518,7 @@ async def nvidia_openai_complete(
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
@@ -610,12 +612,12 @@ async def openai_embedding(
|
||||
)
|
||||
async def nvidia_openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", #refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
api_key: str = None,
|
||||
input_type: str = "passage", #query for retrieval, passage for embedding
|
||||
trunc: str = "NONE", #NONE or START or END
|
||||
encode: str = "float" #float or base64
|
||||
input_type: str = "passage", # query for retrieval, passage for embedding
|
||||
trunc: str = "NONE", # NONE or START or END
|
||||
encode: str = "float", # float or base64
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
@@ -624,10 +626,14 @@ async def nvidia_openai_embedding(
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format=encode, extra_body={"input_type": input_type, "truncate": trunc}
|
||||
model=model,
|
||||
input=texts,
|
||||
encoding_format=encode,
|
||||
extra_body={"input_type": input_type, "truncate": trunc},
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
|
@@ -297,7 +297,9 @@ async def extract_entities(
|
||||
chunk_dp = chunk_key_dp[1]
|
||||
content = chunk_dp["content"]
|
||||
# hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
|
||||
hint_prompt = entity_extract_prompt.format(**context_base, input_text="{input_text}").format(**context_base, input_text=content)
|
||||
hint_prompt = entity_extract_prompt.format(
|
||||
**context_base, input_text="{input_text}"
|
||||
).format(**context_base, input_text=content)
|
||||
|
||||
final_result = await use_llm_func(hint_prompt)
|
||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
||||
@@ -949,7 +951,6 @@ async def _find_related_text_unit_from_relationships(
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
for dp in edge_datas
|
||||
]
|
||||
|
||||
all_text_units_lookup = {}
|
||||
|
||||
for index, unit_list in enumerate(text_units):
|
||||
|
Reference in New Issue
Block a user