Update oracle_impl.py
This commit is contained in:
@@ -502,79 +502,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
#################### INSERT method ################
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    """Insert or update a graph node in the Oracle backend.

    Embeds the node's text and merges the row via the "merge_node"
    SQL template.

    Args:
        node_id: Entity name; used as the node's identifier.
        node_data: Must contain the keys "entity_type", "description"
            and "source_id" (a missing key raises KeyError).
    """
    entity_name = node_id
    entity_type = node_data["entity_type"]
    description = node_data["description"]
    source_id = node_data["source_id"]
    # The embedded text is the concatenation of name and description.
    content = entity_name + description
    # Exactly one text is embedded here, so the original batch/gather
    # machinery was dead generality; a single call yields the same vector.
    embeddings = np.asarray(await self.embedding_func([content]))
    content_vector = embeddings[0]
    # NOTE(review): .format() appears to fill template placeholders while
    # the values themselves are passed as bind parameters below — confirm
    # the template only interpolates non-user-controlled identifiers.
    merge_sql = SQL_TEMPLATES["merge_node"].format(
        workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
    )
    await self.db.execute(
        merge_sql,
        [
            self.db.workspace,
            entity_name,
            entity_type,
            description,
            source_id,
            content,
            content_vector,
        ],
    )
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
    """Insert or update a graph edge in the Oracle backend.

    Embeds the edge's text and merges the row via the "merge_edge"
    SQL template.

    Args:
        source_node_id: Name of the edge's source entity.
        target_node_id: Name of the edge's target entity.
        edge_data: Must contain the keys "weight", "keywords",
            "description" and "source_id" (a missing key raises KeyError).
    """
    source_name = source_node_id
    target_name = target_node_id
    weight = edge_data["weight"]
    keywords = edge_data["keywords"]
    description = edge_data["description"]
    source_chunk_id = edge_data["source_id"]
    # The embedded text concatenates keywords, both endpoint names and
    # the description, in that order.
    content = keywords + source_name + target_name + description
    # Exactly one text is embedded here, so the original batch/gather
    # machinery was dead generality; a single call yields the same vector.
    embeddings = np.asarray(await self.embedding_func([content]))
    content_vector = embeddings[0]
    # NOTE(review): .format() appears to fill template placeholders while
    # the values themselves are passed as bind parameters below — confirm
    # the template only interpolates non-user-controlled identifiers.
    merge_sql = SQL_TEMPLATES["merge_edge"].format(
        workspace=self.db.workspace,
        source_name=source_name,
        target_name=target_name,
        source_chunk_id=source_chunk_id,
    )
    await self.db.execute(
        merge_sql,
        [
            self.db.workspace,
            source_name,
            target_name,
            weight,
            keywords,
            description,
            source_chunk_id,
            content,
            content_vector,
        ],
    )
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
    """Run the named node-embedding algorithm and return its result.

    Looks *algorithm* up in ``self._node_embed_algorithms`` and awaits
    the registered coroutine.

    Raises:
        ValueError: If no algorithm is registered under that name.
    """
    try:
        run_algorithm = self._node_embed_algorithms[algorithm]
    except KeyError:
        raise ValueError(f"Node embedding algorithm {algorithm} not supported") from None
    return await run_algorithm()
|
|
||||||
async def _node2vec_embed(self):
    """Compute node2vec embeddings for the in-memory graph.

    Returns:
        A pair ``(embeddings, node_ids)`` where *embeddings* is the
        matrix produced by graspologic and *node_ids* lists each node's
        ``"id"`` attribute in the same order.
    """
    # Imported lazily so graspologic is only required when this
    # algorithm is actually selected.
    from graspologic import embed

    params = self.config["node2vec_params"]
    vectors, ordered_nodes = embed.node2vec_embed(self._graph, **params)

    graph_nodes = self._graph.nodes
    node_ids = [graph_nodes[node]["id"] for node in ordered_nodes]
    return vectors, node_ids
N_T = {
|
N_T = {
|
||||||
|
Reference in New Issue
Block a user