adding neo4j integration
@@ -97,14 +97,66 @@ class NanoVectorDBStorage(BaseVectorStorage):
        d["__vector__"] = embeddings[i]
        results = self._client.upsert(datas=list_data)
        return results

@dataclass
class PineConeVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # Lazy import keeps pinecone an optional dependency.
        from pinecone import Pinecone

        # The SDK picks up the API key from the PINECONE_API_KEY environment
        # variable; it can also be passed explicitly: Pinecone(api_key=...).
        pc = Pinecone()
        # From here on, everything is identical to the REST-based SDK.
        # e.g. host='my-index-8833ca1.svc.us-east1-gcp.pinecone.io'
        self._client = pc.Index(host=self._client_pinecone_host)
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You inserted empty data to the vector DB")
            return []
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        # TODO: Pinecone form is self._client.upsert(vectors=[...])
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        # TODO: Pinecone form is self._client.query(vector=[...], top_k=10)
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results

@@ -112,7 +164,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
        ]
        return results

    async def index_done_callback(self):
        # Pinecone persists upserts server-side, so there is nothing to save
        # locally; the NanoVectorDB backend would call self._client.save() here.
        pass
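
# A minimal sketch (not part of this commit) of what the two "# TODO: Pinecone
# form" calls above could look like against the Pinecone SDK: Index.upsert
# takes (id, values, metadata) records and Index.query takes a `vector`
# argument; Pinecone has no `better_than_threshold`, so the cosine cutoff has
# to be applied client-side from the returned scores.
#
#     vectors = [
#         {
#             "id": d["__id__"],
#             "values": d["__vector__"].tolist(),
#             "metadata": {k: v for k, v in d.items() if k not in ("__id__", "__vector__")},
#         }
#         for d in list_data
#     ]
#     self._client.upsert(vectors=vectors)
#
#     response = self._client.query(
#         vector=embedding.tolist(), top_k=top_k, include_metadata=True
#     )
#     results = [
#         {**m.metadata, "id": m.id, "distance": m.score}
#         for m in response.matches
#         if m.score >= self.cosine_better_than_threshold
#     ]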
@dataclass
@@ -243,3 +296,187 @@ class NetworkXStorage(BaseGraphStorage):
        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

@dataclass
class Neo4JStorage(BaseGraphStorage):
    @staticmethod
    def load_nx_graph(file_name) -> nx.Graph:
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    # @staticmethod
    # def write_nx_graph(graph: nx.Graph, file_name):
    #     logger.info(
    #         f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
    #     )
    #     nx.write_graphml(graph, file_name)

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.Graph()
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
        # NOTE: the methods below use self.driver, which is never created here;
        # it would come from something like
        # GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")).

    async def index_done_callback(self):
        print("KG successfully indexed.")
        # Neo4JStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        entity_name_label = node_id
        with self.driver.session() as session:
            return session.read_transaction(self._check_node_exists, entity_name_label)

    @staticmethod
    def _check_node_exists(tx, label):
        query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
        result = tx.run(query)
        return result.single()["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        entity_name_label_source = source_node_id
        entity_name_label_target = target_node_id
        # TODO: the relationship type is currently ignored; any edge type matches.
        with self.driver.session() as session:
            result = session.read_transaction(
                self._check_edge_existence,
                entity_name_label_source,
                entity_name_label_target,
            )
            return result

    @staticmethod
    def _check_edge_existence(tx, label1, label2):
        query = (
            f"MATCH (a:`{label1}`)-[r]-(b:`{label2}`) "
            "RETURN COUNT(r) > 0 AS edgeExists"
        )
        result = tx.run(query)
        return result.single()["edgeExists"]

    def close(self):
        self.driver.close()

    async def get_node(self, node_id: str) -> Union[dict, None]:
        entity_name_label = node_id
        # NOTE: this looks nodes up by a name property, unlike the label-based
        # queries above.
        with self.driver.session() as session:
            result = session.run(
                "MATCH (n) WHERE n.name = $name RETURN n",
                name=entity_name_label,
            )
            for record in result:
                return record["n"]  # return the first matching node
            return None

    async def node_degree(self, node_id: str) -> int:
        entity_name_label = node_id
        with self.driver.session() as session:
            degree = self.find_node_degree(session, entity_name_label)
            return degree

    @staticmethod
    def find_node_degree(session, label):
        with session.begin_transaction() as tx:
            result = tx.run(f"MATCH (n:`{label}`) RETURN n, size((n)--()) AS degree")
            record = result.single()
            if record:
                return record["degree"]
            else:
                return None
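
    # Note: the size((n)--()) pattern used above is deprecated and removed in
    # Neo4j 5. A version-portable sketch of the same degree count (same label
    # convention assumed; aggregates over all nodes carrying the label):
    #
    # result = tx.run(
    #     f"MATCH (n:`{label}`) OPTIONAL MATCH (n)--(m) RETURN count(m) AS degree"
    # )
    # record = result.single()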

    # edge_degree
    # from neo4j import GraphDatabase
    #
    # driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    #
    # def edge_degree(tx, source_id, target_id):
    #     result = tx.run("""
    #         MATCH (source) WHERE ID(source) = $source_id
    #         MATCH (target) WHERE ID(target) = $target_id
    #         MATCH (source)-[r]-(target)
    #         RETURN COUNT(r) AS degree
    #     """, source_id=source_id, target_id=target_id)
    #     return result.single()["degree"]
    #
    # with driver.session() as session:
    #     degree = session.read_transaction(edge_degree, 1, 2)
    #     print("Degree of edge between source and target:", degree)

    # get_edge
    # def get_edge(driver, node_id):
    #     with driver.session() as session:
    #         result = session.run(
    #             """
    #             MATCH (n)-[r]-(m)
    #             WHERE id(n) = $node_id
    #             RETURN r
    #             """,
    #             node_id=node_id,
    #         )
    #         return [record["r"] for record in result]
    #
    # driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    # edges = get_edge(driver, 123)  # replace 123 with the actual node ID
    # for edge in edges:
    #     print(f"Edge ID: {edge.id}, Type: {edge.type}, Start: {edge.start_node.id}, End: {edge.end_node.id}")
    # driver.close()
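
    # The same idea as an instance method on this class (sketch; the method
    # name and return shape are assumptions, not part of this commit):
    #
    # async def get_node_edges(self, node_id: int):
    #     def _edges(tx, node_id):
    #         result = tx.run(
    #             "MATCH (n)-[r]-(m) WHERE id(n) = $node_id RETURN r",
    #             node_id=node_id,
    #         )
    #         return [record["r"] for record in result]
    #
    #     with self.driver.session() as session:
    #         return session.read_transaction(_edges, node_id)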

    # upsert_node / add_node
    # async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    #     node_name = node_id
    #     with driver.session() as session:
    #         # BUG: labels cannot be query parameters in Cypher, so $node_name
    #         # will not work here; see the sketch below this block.
    #         session.run("CREATE (p:$node_name $node_data)",
    #                     node_name=node_name, node_data=node_data)
    #
    # with GraphDatabase.driver(URI, auth=AUTH) as driver:
    #     add_node(driver, entity, data)

    # async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
    # def add_edge_with_data(tx, source_node_id, target_node_id, relationship_type, edge_data: dict[str, str]):
    #     # BUG: relationship types cannot be query parameters either.
    #     tx.run("MATCH (s), (t) WHERE id(s) = $source_node_id AND id(t) = $target_node_id "
    #            "CREATE (s)-[r:$relationship_type]->(t) SET r = $data",
    #            source_node_id=source_node_id, target_node_id=target_node_id,
    #            relationship_type=relationship_type, data=edge_data)
    #
    # with driver.session() as session:
    #     session.write_transaction(add_edge_with_data, 1, 2, "KNOWS", {"since": 2020, "strength": 5})
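
    # Since labels and relationship types cannot be Cypher parameters, a
    # working upsert has to interpolate them after validation. A sketch under
    # that assumption (reusing the hypothetical _sanitize_label guard above;
    # MERGE + SET += is the idiomatic Cypher upsert):
    #
    # async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    #     label = Neo4JStorage._sanitize_label(node_id)
    #     with self.driver.session() as session:
    #         session.write_transaction(
    #             lambda tx: tx.run(f"MERGE (n:`{label}`) SET n += $props", props=node_data)
    #         )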

    # async def _node2vec_embed(self):
    #     with driver.session() as session:
    #         # Define the Cypher query
    #         options = self.global_config["node2vec_params"]
    #         # BUG: {**options} is Python dict-splat syntax and will not render
    #         # into valid Cypher; see the sketch below this block.
    #         query = f"""CALL gds.node2vec.stream('myGraph', {**options})
    #             YIELD nodeId, embedding
    #             RETURN nodeId, embedding"""
    #         # Run the query and process the results
    #         results = session.run(query)
    #         for record in results:
    #             node_id = record["nodeId"]
    #             embedding = record["embedding"]
    #             print(f"Node ID: {node_id}, Embedding: {embedding}")
    #         # need to return two lists here (embeddings and node ids).
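
    # A sketch (assumptions: the GDS plugin is installed, a projection named
    # 'myGraph' exists, and the GDS config map is passed as a query parameter)
    # of the implementation the block above is working toward, returning the
    # two lists the trailing note asks for:
    #
    # async def _node2vec_embed(self):
    #     options = self.global_config["node2vec_params"]
    #     embeddings, node_ids = [], []
    #     with self.driver.session() as session:
    #         results = session.run(
    #             "CALL gds.node2vec.stream('myGraph', $config) "
    #             "YIELD nodeId, embedding RETURN nodeId, embedding",
    #             config=options,
    #         )
    #         for record in results:
    #             node_ids.append(record["nodeId"])
    #             embeddings.append(record["embedding"])
    #     return np.array(embeddings), node_ids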