Merge branch 'main' of https://github.com/jin38324/LightRAG
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.0.0"
+__version__ = "1.0.1"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -86,9 +86,6 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return single_result["edgeExists"]

-    def close(self):
-        self._driver.close()
-
     async def get_node(self, node_id: str) -> Union[dict, None]:
         async with self._driver.session() as session:
             entity_name_label = node_id.strip('"')
@@ -214,6 +211,7 @@ class Neo4JStorage(BaseGraphStorage):
                 neo4jExceptions.ServiceUnavailable,
                 neo4jExceptions.TransientError,
                 neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
             )
         ),
     )
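Note: the tuple above is the exception list handed to a tenacity-style retry decorator; adding `neo4jExceptions.ClientError` makes client-side errors retryable as well. A minimal sketch of the pattern, assuming the surrounding code uses `tenacity` (the `retry_if_exception_type` idiom suggests it); the decorated function name is illustrative:

```python
from neo4j import exceptions as neo4jExceptions
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# Sketch: retry transient Neo4j failures with exponential backoff. Adding
# ClientError to the tuple makes client-side errors retryable too.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def run_write_query(session, query: str, **params):
    return await session.run(query, **params)
```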
@@ -114,7 +114,7 @@ class OracleDB:

         logger.info("Finished check all tables in Oracle database")

-    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
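Note: dropping the `params` argument means every Oracle call site below now renders its SQL with `str.format` before execution instead of handing bind variables to the driver. A sketch of the two styles this diff converts between, assuming python-oracledb's asyncio API (`fetch_doc` and the trimmed statement are illustrative); with `.format`, quoting of the interpolated values becomes the caller's responsibility:

```python
import oracledb

# pool = oracledb.create_pool_async(user=..., password=..., dsn=...)

async def fetch_doc(pool, workspace: str, doc_id: str):
    async with pool.acquire() as connection:
        cursor = connection.cursor()

        # Old style (removed): named bind variables; the driver handles
        # quoting and the statement text stays constant across calls.
        await cursor.execute(
            "SELECT content FROM LIGHTRAG_DOC_FULL"
            " WHERE workspace = :workspace AND id = :id",
            {"workspace": workspace, "id": doc_id},
        )

        # New style (added): values interpolated with str.format, so the
        # statement arrives at the driver fully rendered.
        sql = (
            "SELECT content FROM LIGHTRAG_DOC_FULL"
            " WHERE workspace = '{workspace}' AND id = '{id}'"
        ).format(workspace=workspace, id=doc_id)
        await cursor.execute(sql)
        return await cursor.fetchone()
```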
@@ -173,10 +173,11 @@ class OracleKVStorage(BaseKVStorage):

     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
-        params = {"workspace":self.db.workspace, "id":id}
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+            workspace=self.db.workspace, id=id
+        )
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             data = res  # {"data":res}
             # print (data)
@@ -187,11 +188,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by ids."""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
-        params = {"workspace":self.db.workspace}
-        # print("get_by_ids:"+SQL)
-        # print(params)
-        res = await self.db.query(SQL,params, multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+        )
+        # print("get_by_ids:"+SQL)
+        res = await self.db.query(SQL, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
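Note: `get_by_ids` fills the SQL `IN` list by quoting and joining the ids directly into the statement text. For example:

```python
# Illustrative: how the {ids} placeholder in get_by_ids is filled.
ids = ["doc-1", "doc-2"]
in_list = ",".join([f"'{id}'" for id in ids])  # "'doc-1','doc-2'"
SQL = "SELECT * FROM LIGHTRAG_DOC_CHUNKS WHERE id IN ({ids})".format(ids=in_list)
```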
@@ -201,16 +202,12 @@ class OracleKVStorage(BaseKVStorage):

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicate content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
-                                                  ids=",".join([f"'{id}'" for id in keys]))
-        params = {"workspace":self.db.workspace}
-        try:
-            await self.db.query(SQL, params)
-        except Exception as e:
-            logger.error(f"Oracle database error: {e}")
-            print(SQL)
-            print(params)
-        res = await self.db.query(SQL, params,multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            workspace=self.db.workspace,
+            ids=",".join([f"'{k}'" for k in keys]),
+        )
+        res = await self.db.query(SQL, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -247,29 +244,27 @@ class OracleKVStorage(BaseKVStorage):
             d["__vector__"] = embeddings[i]
         # print(list_data)
         for item in list_data:
-            merge_sql = SQL_TEMPLATES["merge_chunk"]
-            data = {"check_id":item["__id__"],
-                    "id":item["__id__"],
-                    "content":item["content"],
-                    "workspace":self.db.workspace,
-                    "tokens":item["tokens"],
-                    "chunk_order_index":item["chunk_order_index"],
-                    "full_doc_id":item["full_doc_id"],
-                    "content_vector":item["__vector__"]
-                    }
+            merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
+
+            values = [
+                item["__id__"],
+                item["content"],
+                self.db.workspace,
+                item["tokens"],
+                item["chunk_order_index"],
+                item["full_doc_id"],
+                item["__vector__"],
+            ]
             # print(merge_sql)
-            await self.db.execute(merge_sql, data)
+            await self.db.execute(merge_sql, values)

         if self.namespace == "full_docs":
             for k, v in self._data.items():
                 # values.clear()
-                merge_sql = SQL_TEMPLATES["merge_doc_full"]
-                data = {
-                    "check_id":k,
-                    "id":k,
-                    "content":v["content"],
-                    "workspace":self.db.workspace
-                }
+                merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
+                    check_id=k,
+                )
+                values = [k, self._data[k]["content"], self.db.workspace]
                 # print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)
         return left_data
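Note: the rewritten upserts format the `check_id` into the MERGE statement and pass the column values as a positional list, which pairs with the `:1,:2,...` placeholders visible in the SQL_TEMPLATES hunk further down. A minimal sketch of the pattern, assuming python-oracledb's asyncio API (the trimmed-down statement and helper name are illustrative):

```python
# Sketch of the positional-bind upsert the new call sites perform; the
# statement is a cut-down stand-in for the real merge_chunk template.
MERGE_CHUNK_SKETCH = """MERGE INTO LIGHTRAG_DOC_CHUNKS a
    USING DUAL ON (a.id = '{check_id}')
    WHEN NOT MATCHED THEN
    INSERT (id, content, workspace) VALUES (:1, :2, :3)"""

async def upsert_chunk(connection, chunk_id: str, content: str, workspace: str):
    sql = MERGE_CHUNK_SKETCH.format(check_id=chunk_id)
    cursor = connection.cursor()
    # The list pairs positionally with :1, :2, :3 in the INSERT clause.
    await cursor.execute(sql, [chunk_id, content, workspace])
    await connection.commit()
```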
@@ -301,17 +296,18 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # Convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
+        embedding_string = ", ".join(map(str, embedding.tolist()))

-        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
-        params = {
-            "embedding_string": embedding_string,
-            "workspace": self.db.workspace,
-            "top_k": top_k,
-            "better_than_threshold": self.cosine_better_than_threshold,
-        }
+        SQL = SQL_TEMPLATES[self.namespace].format(
+            embedding_string=embedding_string,
+            dimension=dimension,
+            dtype=dtype,
+            workspace=self.db.workspace,
+            top_k=top_k,
+            better_than_threshold=self.cosine_better_than_threshold,
+        )
         # print(SQL)
-        results = await self.db.query(SQL,params=params, multirows=True)
+        results = await self.db.query(SQL, multirows=True)
         # print("vector search result:",results)
         return results
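Note: the brackets move out of `embedding_string` because the query template now wraps the vector literal itself. Illustrative only, assuming Oracle 23ai vector syntax; the real statement lives in `SQL_TEMPLATES[self.namespace]`:

```python
# dtype arrives as e.g. "FLOAT32" from str(embedding.dtype).upper(), and
# better_than_threshold would be applied as a filter on the distance.
VECTOR_SEARCH_SKETCH = """SELECT id, content
    FROM LIGHTRAG_DOC_CHUNKS
    WHERE workspace = '{workspace}'
    ORDER BY VECTOR_DISTANCE(content_vector,
        TO_VECTOR('[{embedding_string}]', {dimension}, {dtype}), COSINE)
    FETCH FIRST {top_k} ROWS ONLY"""
```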
@@ -346,18 +342,22 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_node"]
-        data = {
-            "workspace":self.db.workspace,
-            "name":entity_name,
-            "entity_type":entity_type,
-            "description":description,
-            "source_chunk_id":source_id,
-            "content":content,
-            "content_vector":content_vector
-        }
+        merge_sql = SQL_TEMPLATES["merge_node"].format(
+            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
+        )
         # print(merge_sql)
-        await self.db.execute(merge_sql,data)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                entity_name,
+                entity_type,
+                description,
+                source_id,
+                content,
+                content_vector,
+            ],
+        )
         # self._graph.add_node(node_id, **node_data)

     async def upsert_edge(
@@ -371,8 +371,6 @@ class OracleGraphStorage(BaseGraphStorage):
         keywords = edge_data["keywords"]
         description = edge_data["description"]
         source_chunk_id = edge_data["source_id"]
-        logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
-
         content = keywords + source_name + target_name + description
         contents = [content]
         batches = [
@@ -384,20 +382,27 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_edge"]
-        data = {
-            "workspace":self.db.workspace,
-            "source_name":source_name,
-            "target_name":target_name,
-            "weight":weight,
-            "keywords":keywords,
-            "description":description,
-            "source_chunk_id":source_chunk_id,
-            "content":content,
-            "content_vector":content_vector
-        }
+        merge_sql = SQL_TEMPLATES["merge_edge"].format(
+            workspace=self.db.workspace,
+            source_name=source_name,
+            target_name=target_name,
+            source_chunk_id=source_chunk_id,
+        )
         # print(merge_sql)
-        await self.db.execute(merge_sql,data)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                source_name,
+                target_name,
+                weight,
+                keywords,
+                description,
+                source_chunk_id,
+                content,
+                content_vector,
+            ],
+        )
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -427,14 +432,12 @@ class OracleGraphStorage(BaseGraphStorage):
     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by node id."""
-        SQL = SQL_TEMPLATES["has_node"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["has_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
         # print(self.db.workspace, node_id)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Node exist!",res)
             return True
@@ -444,14 +447,13 @@ class OracleGraphStorage(BaseGraphStorage):

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"]
-        params = {
-            "workspace":self.db.workspace,
-            "source_node_id":source_node_id,
-            "target_node_id":target_node_id
-        }
+        SQL = SQL_TEMPLATES["has_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Edge exist!",res)
             return True
@@ -461,13 +463,11 @@ class OracleGraphStorage(BaseGraphStorage):

     async def node_degree(self, node_id: str) -> int:
         """Get a node's degree by node id."""
-        SQL = SQL_TEMPLATES["node_degree"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["node_degree"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Node degree",res["degree"])
             return res["degree"]
@@ -483,14 +483,12 @@ class OracleGraphStorage(BaseGraphStorage):

     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by node id."""
-        SQL = SQL_TEMPLATES["get_node"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["get_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(self.db.workspace, node_id)
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Get node!",self.db.workspace, node_id,res)
             return res
@@ -502,13 +500,12 @@ class OracleGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"]
-        params = {
-            "workspace":self.db.workspace,
-            "source_node_id":source_node_id,
-            "target_node_id":target_node_id
-        }
-        res = await self.db.query(SQL,params)
+        SQL = SQL_TEMPLATES["get_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
+        res = await self.db.query(SQL)
         if res:
             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
@@ -519,12 +516,10 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"]
-            params = {
-                "workspace":self.db.workspace,
-                "source_node_id":source_node_id
-            }
-            res = await self.db.query(sql=SQL, params=params, multirows=True)
+            SQL = SQL_TEMPLATES["get_node_edges"].format(
+                workspace=self.db.workspace, source_node_id=source_node_id
+            )
+            res = await self.db.query(sql=SQL, multirows=True)
             if res:
                 data = [(i["source_name"], i["target_name"]) for i in res]
                 # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -532,29 +527,7 @@ class OracleGraphStorage(BaseGraphStorage):
         else:
             # print("Node Edge not exist!",self.db.workspace, source_node_id)
             return []

-    async def get_all_nodes(self, limit: int):
-        """Query all nodes."""
-        SQL = SQL_TEMPLATES["get_all_nodes"]
-        params = {"workspace":self.db.workspace, "limit":str(limit)}
-        res = await self.db.query(sql=SQL, params=params, multirows=True)
-        if res:
-            return res
-
-    async def get_all_edges(self, limit: int):
-        """Query all edges."""
-        SQL = SQL_TEMPLATES["get_all_edges"]
-        params = {"workspace":self.db.workspace, "limit":str(limit)}
-        res = await self.db.query(sql=SQL, params=params, multirows=True)
-        if res:
-            return res
-
-    async def get_statistics(self):
-        SQL = SQL_TEMPLATES["get_statistics"]
-        params = {"workspace":self.db.workspace}
-        res = await self.db.query(sql=SQL, params=params, multirows=True)
-        if res:
-            return res

 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
@@ -726,37 +699,5 @@ SQL_TEMPLATES = {
     ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
     WHEN NOT MATCHED THEN
     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
-    "get_all_nodes":"""WITH t0 AS (
-        SELECT name AS id, entity_type AS label, entity_type, description,
-            '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
-        FROM lightrag_graph_nodes
-        WHERE workspace = :workspace
-        ORDER BY createtime DESC fetch first :limit rows only
-    ), t1 AS (
-        SELECT t0.id, source_chunk_id
-        FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
-    ), t2 AS (
-        SELECT t1.id, LISTAGG(t2.content, '\n') content
-        FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
-        GROUP BY t1.id
-    )
-    SELECT t0.id, label, entity_type, description, t2.content
-    FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
-    "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
-        t1.weight,t1.DESCRIPTION,t2.content
-    FROM LIGHTRAG_GRAPH_EDGES t1
-    LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
-    WHERE t1.workspace=:workspace
-    order by t1.CREATETIME DESC
-    fetch first :limit rows only""",
-    "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
-        count(distinct CASE WHEN type='edge' THEN id END) as edges_count
-    FROM (
-        select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
-            MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
-        UNION
-        select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
-            MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
-    )""",
+    values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
 }
@@ -172,9 +172,7 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace="chunk_entity_relation",
-            global_config=asdict(self),
-            embedding_func=self.embedding_func,
+            namespace="chunk_entity_relation", global_config=asdict(self)
         )
         ####
         # add embedding func by walter over
@@ -226,6 +224,7 @@ class LightRAG:
         return loop.run_until_complete(self.ainsert(string_or_strings))

     async def ainsert(self, string_or_strings):
+        update_storage = False
         try:
             if isinstance(string_or_strings, str):
                 string_or_strings = [string_or_strings]
@@ -239,6 +238,7 @@
             if not len(new_docs):
                 logger.warning("All docs are already in the storage")
                 return
+            update_storage = True
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")

             inserting_chunks = {}
@@ -285,7 +285,8 @@
             await self.full_docs.upsert(new_docs)
             await self.text_chunks.upsert(inserting_chunks)
         finally:
-            await self._insert_done()
+            if update_storage:
+                await self._insert_done()

     async def _insert_done(self):
         tasks = []
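Note: `_insert_done` flushes the storage backends, so the new flag skips that work when `ainsert` returns early with nothing to write. The shape of the pattern (bodies abbreviated; `_dedupe` is a hypothetical stand-in for the new-docs filtering above):

```python
async def ainsert_sketch(self, string_or_strings):
    update_storage = False
    try:
        new_docs = await self._dedupe(string_or_strings)
        if not new_docs:
            return  # early return leaves update_storage False
        update_storage = True
        await self.full_docs.upsert(new_docs)
    finally:
        if update_storage:  # skip the flush when nothing changed
            await self._insert_done()
```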
@@ -696,13 +696,17 @@ async def bedrock_embedding(


 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()


 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
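Note: `.numpy()` only works on CPU tensors, and numpy has no bfloat16 dtype, so a bfloat16 model output must be upcast before conversion; hence both changes above. For example:

```python
import torch

x = torch.randn(2, 3).to(torch.bfloat16)
# x.numpy() raises TypeError ("Got unsupported ScalarType BFloat16"); the
# same call on a CUDA tensor fails too, so upcast and move to CPU first:
arr = x.detach().to(torch.float32).cpu().numpy()
```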
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
     all_text_units_lookup = {}
     for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
         for c_id in this_text_units:
-            if c_id not in all_text_units_lookup:
-                all_text_units_lookup[c_id] = {
-                    "data": await text_chunks_db.get_by_id(c_id),
-                    "order": index,
-                    "relation_counts": 0,
-                }
-
-            if this_edges:
+            if c_id in all_text_units_lookup:
+                continue
+            relation_counts = 0
+            if this_edges:  # Add check for None edges
                 for e in this_edges:
                     if (
                         e[1] in all_one_hop_text_units_lookup
                         and c_id in all_one_hop_text_units_lookup[e[1]]
                     ):
-                        all_text_units_lookup[c_id]["relation_counts"] += 1
+                        relation_counts += 1
+
+            chunk_data = await text_chunks_db.get_by_id(c_id)
+            if chunk_data is not None and "content" in chunk_data:  # Add content check
+                all_text_units_lookup[c_id] = {
+                    "data": chunk_data,
+                    "order": index,
+                    "relation_counts": relation_counts,
+                }

     # Filter out None values and ensure data has content
     all_text_units = [
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
     all_related_edges = await asyncio.gather(
         *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
     )
-    all_edges = set()
+    all_edges = []
+    seen = set()
+
     for this_edges in all_related_edges:
-        all_edges.update([tuple(sorted(e)) for e in this_edges])
-    all_edges = list(all_edges)
+        for e in this_edges:
+            sorted_edge = tuple(sorted(e))
+            if sorted_edge not in seen:
+                seen.add(sorted_edge)
+                all_edges.append(sorted_edge)

     all_edges_pack = await asyncio.gather(
         *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
     )
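Note: a `set` discards insertion order, so the old code could return edges in a different order across runs; the `seen`-set-plus-list idiom deduplicates while keeping first-seen order. The same idiom is applied in the next two hunks. A standalone example:

```python
# Order-preserving dedup (new) vs. set dedup (old):
edges = [("b", "a"), ("a", "b"), ("c", "a")]

via_set = list({tuple(sorted(e)) for e in edges})  # iteration order arbitrary

seen, ordered = set(), []
for e in edges:
    key = tuple(sorted(e))
    if key not in seen:
        seen.add(key)
        ordered.append(key)
assert ordered == [("a", "b"), ("a", "c")]  # first-seen order kept
```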
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
 ):
-    entity_names = set()
+    entity_names = []
+    seen = set()
+
     for e in edge_datas:
-        entity_names.add(e["src_id"])
-        entity_names.add(e["tgt_id"])
+        if e["src_id"] not in seen:
+            entity_names.append(e["src_id"])
+            seen.add(e["src_id"])
+        if e["tgt_id"] not in seen:
+            entity_names.append(e["tgt_id"])
+            seen.add(e["tgt_id"])

     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
@@ -290,13 +290,19 @@ def process_combine_contexts(hl, ll):
     if list_ll:
         list_ll = [",".join(item[1:]) for item in list_ll if item]

-    combined_sources_set = set(filter(None, list_hl + list_ll))
+    combined_sources = []
+    seen = set()

-    combined_sources = [",\t".join(header)]
+    for item in list_hl + list_ll:
+        if item and item not in seen:
+            combined_sources.append(item)
+            seen.add(item)

-    for i, item in enumerate(combined_sources_set, start=1):
-        combined_sources.append(f"{i},\t{item}")
+    combined_sources_result = [",\t".join(header)]

-    combined_sources = "\n".join(combined_sources)
+    for i, item in enumerate(combined_sources, start=1):
+        combined_sources_result.append(f"{i},\t{item}")

-    return combined_sources
+    combined_sources_result = "\n".join(combined_sources_result)
+
+    return combined_sources_result