From 4ef705b13ff7ac25905af9a479086d91d28eec28 Mon Sep 17 00:00:00 2001 From: Ken Wiltshire Date: Fri, 25 Oct 2024 11:28:41 -0400 Subject: [PATCH] adding neo4j integration --- lightrag/storage.py | 243 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 240 insertions(+), 3 deletions(-) diff --git a/lightrag/storage.py b/lightrag/storage.py index 1f22fc56..704dc4e8 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -97,14 +97,66 @@ class NanoVectorDBStorage(BaseVectorStorage): d["__vector__"] = embeddings[i] results = self._client.upsert(datas=list_data) return results + + +@dataclass +class PineConeVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + self._client_file_name = os.path.join( + self.global_config["working_dir"], f"vdb_{self.namespace}.json" + ) + self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + import os + from pinecone import Pinecone + + pc = Pinecone() #api_key=os.environ.get('PINECONE_API_KEY')) + # From here on, everything is identical to the REST-based SDK. + self._client = pc.Index(host=self._client_pinecone_host)#'my-index-8833ca1.svc.us-east1-gcp.pinecone.io') + + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + list_data = [ + { + "__id__": k, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + # self._client.upsert(vectors=[]) pinecone + results = self._client.upsert(datas=list_data) + return results async def query(self, query: str, top_k=5): embedding = await self.embedding_func([query]) embedding = embedding[0] + # self._client.query(vector=[...], top_key=10) pinecone results = self._client.query( - query=embedding, + vector=embedding, top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, + better_than_threshold=self.cosine_better_than_threshold, ??? ) results = [ {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results @@ -112,7 +164,8 @@ class NanoVectorDBStorage(BaseVectorStorage): return results async def index_done_callback(self): - self._client.save() + print("self._client.save()") + # self._client.save() @dataclass @@ -243,3 +296,187 @@ class NetworkXStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids + + +@dataclass +class Neo4JStorage(BaseGraphStorage): + @staticmethod + def load_nx_graph(file_name) -> nx.Graph: + if os.path.exists(file_name): + return nx.read_graphml(file_name) + return None + + # @staticmethod + # def write_nx_graph(graph: nx.Graph, file_name): + # logger.info( + # f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" + # ) + # nx.write_graphml(graph, file_name) + + + def __post_init__(self): + self._graphml_xml_file = os.path.join( + self.global_config["working_dir"], f"graph_{self.namespace}.graphml" + ) + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + self._graph = preloaded_graph or nx.Graph() + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def index_done_callback(self): + print ("KG successfully indexed.") + # Neo4JStorage.write_nx_graph(self._graph, self._graphml_xml_file) + async def has_node(self, node_id: str) -> bool: + entity_name_label = node_id + with self.driver.session() as session: + return session.read_transaction(self._check_node_exists, entity_name_label) + + @staticmethod + def _check_node_exists(tx, label): + query = f"MATCH (n:{label}) RETURN count(n) > 0 AS node_exists" + result = tx.run(query) + return result.single()["node_exists"] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + entity_name_label_source = source_node_id + entity_name_label_target = target_node_id + #hard code relaitionship type + with self.driver.session() as session: + result = session.read_transaction(self._check_edge_existence, entity_name_label_source, entity_name_label_target) + return result + + @staticmethod + def _check_edge_existence(tx, label1, label2): + query = ( + f"MATCH (a:{label1})-[r]-(b:{label2}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = tx.run(query) + return result.single()["edgeExists"] + def close(self): + self.driver.close() + + + + async def get_node(self, node_id: str) -> Union[dict, None]: + entity_name_label = node_id + with driver.session() as session: + result = session.run( + "MATCH (n) WHERE n.name = $name RETURN n", + name=node_name + ) + + for record in result: + return record["n"] # Return the first matching node + + + + async def node_degree(self, node_id: str) -> int: + entity_name_label = node_id + neo4j = Neo4j("bolt://localhost:7687", "neo4j", "password") + with neo4j.driver.session() as session: + degree = Neo4j.find_node_degree(session, entity_name_label) + return degree + + @staticmethod + def find_node_degree(session, label): + with session.begin_transaction() as tx: + result = tx.run("MATCH (n:`{label}`) RETURN n, size((n)--()) AS degree".format(label=label)) + record = result.single() + if record: + return record["degree"] + else: + return None + +# edge_degree + # from neo4j import GraphDatabase + + # driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) + + # + # + # def edge_degree(tx, source_id, target_id): + # result = tx.run(""" + # MATCH (source) WHERE ID(source) = $source_id + # MATCH (target) WHERE ID(target) = $target_id + # MATCH (source)-[r]-(target) + # RETURN COUNT(r) AS degree + # """, source_id=source_id, target_id=target_id) + + # return result.single()["degree"] + + # with driver.session() as session: + # degree = session.read_transaction(get_edge_degree, 1, 2) + # print("Degree of edge between source and target:", degree) + + + + #get_edge + # def get_edge(driver, node_id): + # with driver.session() as session: + # result = session.run( + # """ + # MATCH (n)-[r]-(m) + # WHERE id(n) = $node_id + # RETURN r + # """, + # node_id=node_id + # ) + # return [record["r"] for record in result] + + # driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) + + # edges = get_node_edges(driver, 123) # Replace 123 with the actual node ID + + # for edge in edges: + # print(f"Edge ID: {edge.id}, Type: {edge.type}, Start: {edge.start_node.id}, End: {edge.end_node.id}") + + # driver.close() + + +#upsert_node + #add_node, upsert_node + # async def upsert_node(self, node_id: str, node_data: dict[str, str]): + # node_name = node_id + # with driver.session() as session: + # session.run("CREATE (p:$node_name $node_data)", node_name=node_name, node_data=**node_data) + + # with GraphDatabase.driver(URI, auth=AUTH) as driver: + # add_node(driver, entity, data) + +#async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]): + # def add_edge_with_data(tx, source_node_id, target_node_id, relationship_type, edge_data: dict[str, str]): + # source_node_name = source_node_id + # target_node_name = target_node_id + # tx.run("MATCH (s), (t) WHERE id(s) = $source_node_id AND id(t) = $target_node_id " + # "CREATE (s)-[r:$relationship_type]->(t) SET r = $data", + # source_node_id=source_node_id, target_node_id=target_node_id, + # relationship_type=relationship_type, data=edge_data) + + # with driver.session() as session: + # session.write_transaction(add_edge_with_data, 1, 2, "KNOWS", {"since": 2020, "strength": 5}) + + +#async def _node2vec_embed(self): + # # async def _node2vec_embed(self): + # with driver.session() as session: + # #Define the Cypher query + # options = self.global_config["node2vec_params"] + # query = f"""CALL gds.node2vec.stream('myGraph', {**options}) + # YIELD nodeId, embedding + # RETURN nodeId, embedding""" + # # Run the query and process the results + # results = session.run(query) + # for record in results: + # node_id = record["nodeId"] + # embedding = record["embedding"] + # print(f"Node ID: {node_id}, Embedding: {embedding}") + # #need to return two lists here. + + +