Update oracle_impl.py
This commit is contained in:
@@ -502,79 +502,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||
return []
|
||||
|
||||
#################### INSERT method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
"""插入或更新节点"""
|
||||
#print("go into upsert node method")
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
content = entity_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
||||
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
|
||||
#self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
"""插入或更新边"""
|
||||
#print("go into upsert edge method")
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
content = keywords+source_name+target_name+description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i: i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
||||
workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
|
||||
)
|
||||
#print(merge_sql)
|
||||
await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
|
||||
#self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
"""为节点生成向量"""
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
"""为节点生成向量"""
|
||||
from graspologic import embed
|
||||
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
|
||||
N_T = {
|
||||
|
Reference in New Issue
Block a user