Update oracle_impl.py

This commit is contained in:
jin
2024-11-12 12:02:24 +08:00
parent d6443326c1
commit 77123be2a1

View File

@@ -503,79 +503,6 @@ class OracleGraphStorage(BaseGraphStorage):
#print("Node Edge not exist!",self.db.workspace, source_node_id) #print("Node Edge not exist!",self.db.workspace, source_node_id)
return [] return []
#################### INSERT method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""插入或更新节点"""
#print("go into upsert node method")
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
content = entity_name+description
contents = [content]
batches = [
contents[i: i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_node"].format(
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
)
#print(merge_sql)
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
#self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
"""插入或更新边"""
#print("go into upsert edge method")
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
content = keywords+source_name+target_name+description
contents = [content]
batches = [
contents[i: i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_edge"].format(
workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
)
#print(merge_sql)
await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
#self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""为节点生成向量"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def _node2vec_embed(self):
"""为节点生成向量"""
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
N_T = { N_T = {
"full_docs": "LIGHTRAG_DOC_FULL", "full_docs": "LIGHTRAG_DOC_FULL",