From 8e9005baad5a3fba1324ddc9e11060f00e9a1b29 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Sun, 20 Oct 2024 23:08:26 +0800
Subject: [PATCH] Add visualization methods
---
.gitignore | 3 +-
README.md | 141 +++++++++++++++++-
...ph_visual.py => graph_visual_with_html.py} | 0
examples/graph_visual_with_neo4j.py | 118 +++++++++++++++
lightrag/utils.py | 49 ++++++
5 files changed, 308 insertions(+), 3 deletions(-)
rename examples/{graph_visual.py => graph_visual_with_html.py} (100%)
create mode 100644 examples/graph_visual_with_neo4j.py
diff --git a/.gitignore b/.gitignore
index edfbfbfc..5a41ae32 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ __pycache__
dickens/
book.txt
lightrag-dev/
-.idea/
\ No newline at end of file
+.idea/
+dist/
\ No newline at end of file
diff --git a/README.md b/README.md
index c8d6e312..89e50aa0 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
## π News
+- [x] [2024.10.20]π―π―π’π’We add two methods to visualize the graph.
- [x] [2024.10.18]π―π―π’π’Weβve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.17]π―π―π’π’We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! ππ
- [x] [2024.10.16]π―π―π’π’LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
@@ -221,7 +222,11 @@ with open("./newText.txt") as f:
### Graph Visualization
-* Generate html file
+
+ Graph visualization with html
+
+* The following code can be found in `examples/graph_visual_with_html.py`
+
```python
import networkx as nx
from pyvis.network import Network
@@ -238,6 +243,137 @@ net.from_nx(G)
# Save and display the network
net.show('knowledge_graph.html')
```
+
+
+
+
+ Graph visualization with Neo4j
+
+* The following code can be found in `examples/graph_visual_with_neo4j.py`
+
+```python
+import os
+import json
+from lightrag.utils import xml_to_json
+from neo4j import GraphDatabase
+
+# Constants
+WORKING_DIR = "./dickens"
+BATCH_SIZE_NODES = 500
+BATCH_SIZE_EDGES = 100
+
+# Neo4j connection credentials
+NEO4J_URI = "bolt://localhost:7687"
+NEO4J_USERNAME = "neo4j"
+NEO4J_PASSWORD = "your_password"
+
+def convert_xml_to_json(xml_path, output_path):
+ """Converts XML file to JSON and saves the output."""
+ if not os.path.exists(xml_path):
+ print(f"Error: File not found - {xml_path}")
+ return None
+
+ json_data = xml_to_json(xml_path)
+ if json_data:
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
+ print(f"JSON file created: {output_path}")
+ return json_data
+ else:
+ print("Failed to create JSON data")
+ return None
+
+def process_in_batches(tx, query, data, batch_size):
+ """Process data in batches and execute the given query."""
+ for i in range(0, len(data), batch_size):
+ batch = data[i:i + batch_size]
+ tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
+def main():
+ # Paths
+ xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
+ json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+
+ # Convert XML to JSON
+ json_data = convert_xml_to_json(xml_file, json_file)
+ if json_data is None:
+ return
+
+ # Load nodes and edges
+ nodes = json_data.get('nodes', [])
+ edges = json_data.get('edges', [])
+
+ # Neo4j queries
+ create_nodes_query = """
+ UNWIND $nodes AS node
+ MERGE (e:Entity {id: node.id})
+ SET e.entity_type = node.entity_type,
+ e.description = node.description,
+ e.source_id = node.source_id,
+ e.displayName = node.id
+ REMOVE e:Entity
+ WITH e, node
+ CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
+ RETURN count(*)
+ """
+
+ create_edges_query = """
+ UNWIND $edges AS edge
+ MATCH (source {id: edge.source})
+ MATCH (target {id: edge.target})
+ WITH source, target, edge,
+ CASE
+ WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
+ WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
+ WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
+ WHEN edge.keywords CONTAINS 'located' THEN 'located'
+ WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
+ ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
+ END AS relType
+ CALL apoc.create.relationship(source, relType, {
+ weight: edge.weight,
+ description: edge.description,
+ keywords: edge.keywords,
+ source_id: edge.source_id
+ }, target) YIELD rel
+ RETURN count(*)
+ """
+
+ set_displayname_and_labels_query = """
+ MATCH (n)
+ SET n.displayName = n.id
+ WITH n
+ CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
+ RETURN count(*)
+ """
+
+ # Create a Neo4j driver
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+
+ try:
+ # Execute queries in batches
+ with driver.session() as session:
+ # Insert nodes in batches
+ session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+
+ # Insert edges in batches
+ session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+
+ # Set displayName and labels
+ session.run(set_displayname_and_labels_query)
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+ finally:
+ driver.close()
+
+if __name__ == "__main__":
+ main()
+```
+
+
+
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -484,8 +620,9 @@ def extract_queries(file_path):
.
βββ examples
β βββ batch_eval.py
+β βββ graph_visual_with_html.py
+β βββ graph_visual_with_neo4j.py
β βββ generate_query.py
-β βββ graph_visual.py
β βββ lightrag_azure_openai_demo.py
β βββ lightrag_bedrock_demo.py
β βββ lightrag_hf_demo.py
diff --git a/examples/graph_visual.py b/examples/graph_visual_with_html.py
similarity index 100%
rename from examples/graph_visual.py
rename to examples/graph_visual_with_html.py
diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py
new file mode 100644
index 00000000..22dde368
--- /dev/null
+++ b/examples/graph_visual_with_neo4j.py
@@ -0,0 +1,118 @@
+import os
+import json
+from lightrag.utils import xml_to_json
+from neo4j import GraphDatabase
+
+# Constants
+WORKING_DIR = "./dickens"
+BATCH_SIZE_NODES = 500
+BATCH_SIZE_EDGES = 100
+
+# Neo4j connection credentials
+NEO4J_URI = "bolt://localhost:7687"
+NEO4J_USERNAME = "neo4j"
+NEO4J_PASSWORD = "your_password"
+
+def convert_xml_to_json(xml_path, output_path):
+ """Converts XML file to JSON and saves the output."""
+ if not os.path.exists(xml_path):
+ print(f"Error: File not found - {xml_path}")
+ return None
+
+ json_data = xml_to_json(xml_path)
+ if json_data:
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
+ print(f"JSON file created: {output_path}")
+ return json_data
+ else:
+ print("Failed to create JSON data")
+ return None
+
+def process_in_batches(tx, query, data, batch_size):
+ """Process data in batches and execute the given query."""
+ for i in range(0, len(data), batch_size):
+ batch = data[i:i + batch_size]
+ tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
+def main():
+ # Paths
+ xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
+ json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+
+ # Convert XML to JSON
+ json_data = convert_xml_to_json(xml_file, json_file)
+ if json_data is None:
+ return
+
+ # Load nodes and edges
+ nodes = json_data.get('nodes', [])
+ edges = json_data.get('edges', [])
+
+ # Neo4j queries
+ create_nodes_query = """
+ UNWIND $nodes AS node
+ MERGE (e:Entity {id: node.id})
+ SET e.entity_type = node.entity_type,
+ e.description = node.description,
+ e.source_id = node.source_id,
+ e.displayName = node.id
+ REMOVE e:Entity
+ WITH e, node
+ CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
+ RETURN count(*)
+ """
+
+ create_edges_query = """
+ UNWIND $edges AS edge
+ MATCH (source {id: edge.source})
+ MATCH (target {id: edge.target})
+ WITH source, target, edge,
+ CASE
+ WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
+ WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
+ WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
+ WHEN edge.keywords CONTAINS 'located' THEN 'located'
+ WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
+ ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
+ END AS relType
+ CALL apoc.create.relationship(source, relType, {
+ weight: edge.weight,
+ description: edge.description,
+ keywords: edge.keywords,
+ source_id: edge.source_id
+ }, target) YIELD rel
+ RETURN count(*)
+ """
+
+ set_displayname_and_labels_query = """
+ MATCH (n)
+ SET n.displayName = n.id
+ WITH n
+ CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
+ RETURN count(*)
+ """
+
+ # Create a Neo4j driver
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+
+ try:
+ # Execute queries in batches
+ with driver.session() as session:
+ # Insert nodes in batches
+ session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+
+ # Insert edges in batches
+ session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+
+ # Set displayName and labels
+ session.run(set_displayname_and_labels_query)
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+ finally:
+ driver.close()
+
+if __name__ == "__main__":
+ main()
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 67d094c6..9a68c16b 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union
+import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
@@ -183,3 +184,51 @@ def list_of_list_to_csv(data: list[list]):
def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
+
+def xml_to_json(xml_file):
+ try:
+ tree = ET.parse(xml_file)
+ root = tree.getroot()
+
+ # Print the root element's tag and attributes to confirm the file has been correctly loaded
+ print(f"Root element: {root.tag}")
+ print(f"Root attributes: {root.attrib}")
+
+ data = {
+ "nodes": [],
+ "edges": []
+ }
+
+ # Use namespace
+ namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+
+ for node in root.findall('.//node', namespace):
+ node_data = {
+ "id": node.get('id').strip('"'),
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
+ "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
+ "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+ }
+ data["nodes"].append(node_data)
+
+ for edge in root.findall('.//edge', namespace):
+ edge_data = {
+ "source": edge.get('source').strip('"'),
+ "target": edge.get('target').strip('"'),
+ "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
+ "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
+ "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
+ "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+ }
+ data["edges"].append(edge_data)
+
+ # Print the number of nodes and edges found
+ print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
+
+ return data
+ except ET.ParseError as e:
+ print(f"Error parsing XML file: {e}")
+ return None
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None