Add visualization methods
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,3 +4,4 @@ dickens/
|
|||||||
book.txt
|
book.txt
|
||||||
lightrag-dev/
|
lightrag-dev/
|
||||||
.idea/
|
.idea/
|
||||||
|
dist/
|
141
README.md
141
README.md
@@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
## 🎉 News
|
## 🎉 News
|
||||||
|
- [x] [2024.10.20]🎯🎯📢📢We add two methods to visualize the graph.
|
||||||
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
||||||
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
|
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
|
||||||
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
||||||
@@ -221,7 +222,11 @@ with open("./newText.txt") as f:
|
|||||||
|
|
||||||
### Graph Visualization
|
### Graph Visualization
|
||||||
|
|
||||||
* Generate html file
|
<details>
|
||||||
|
<summary> Graph visualization with html </summary>
|
||||||
|
|
||||||
|
* The following code can be found in `examples/graph_visual_with_html.py`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pyvis.network import Network
|
from pyvis.network import Network
|
||||||
@@ -238,6 +243,137 @@ net.from_nx(G)
|
|||||||
# Save and display the network
|
# Save and display the network
|
||||||
net.show('knowledge_graph.html')
|
net.show('knowledge_graph.html')
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Graph visualization with Neo4j </summary>
|
||||||
|
|
||||||
|
* The following code can be found in `examples/graph_visual_with_neo4j.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from lightrag.utils import xml_to_json
|
||||||
|
from neo4j import GraphDatabase
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
BATCH_SIZE_NODES = 500
|
||||||
|
BATCH_SIZE_EDGES = 100
|
||||||
|
|
||||||
|
# Neo4j connection credentials
|
||||||
|
NEO4J_URI = "bolt://localhost:7687"
|
||||||
|
NEO4J_USERNAME = "neo4j"
|
||||||
|
NEO4J_PASSWORD = "your_password"
|
||||||
|
|
||||||
|
def convert_xml_to_json(xml_path, output_path):
|
||||||
|
"""Converts XML file to JSON and saves the output."""
|
||||||
|
if not os.path.exists(xml_path):
|
||||||
|
print(f"Error: File not found - {xml_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_data = xml_to_json(xml_path)
|
||||||
|
if json_data:
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(json_data, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"JSON file created: {output_path}")
|
||||||
|
return json_data
|
||||||
|
else:
|
||||||
|
print("Failed to create JSON data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_in_batches(tx, query, data, batch_size):
|
||||||
|
"""Process data in batches and execute the given query."""
|
||||||
|
for i in range(0, len(data), batch_size):
|
||||||
|
batch = data[i:i + batch_size]
|
||||||
|
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Paths
|
||||||
|
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
|
||||||
|
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
|
||||||
|
|
||||||
|
# Convert XML to JSON
|
||||||
|
json_data = convert_xml_to_json(xml_file, json_file)
|
||||||
|
if json_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load nodes and edges
|
||||||
|
nodes = json_data.get('nodes', [])
|
||||||
|
edges = json_data.get('edges', [])
|
||||||
|
|
||||||
|
# Neo4j queries
|
||||||
|
create_nodes_query = """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (e:Entity {id: node.id})
|
||||||
|
SET e.entity_type = node.entity_type,
|
||||||
|
e.description = node.description,
|
||||||
|
e.source_id = node.source_id,
|
||||||
|
e.displayName = node.id
|
||||||
|
REMOVE e:Entity
|
||||||
|
WITH e, node
|
||||||
|
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
create_edges_query = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (source {id: edge.source})
|
||||||
|
MATCH (target {id: edge.target})
|
||||||
|
WITH source, target, edge,
|
||||||
|
CASE
|
||||||
|
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
|
||||||
|
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
|
||||||
|
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
|
||||||
|
WHEN edge.keywords CONTAINS 'located' THEN 'located'
|
||||||
|
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
|
||||||
|
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
|
||||||
|
END AS relType
|
||||||
|
CALL apoc.create.relationship(source, relType, {
|
||||||
|
weight: edge.weight,
|
||||||
|
description: edge.description,
|
||||||
|
keywords: edge.keywords,
|
||||||
|
source_id: edge.source_id
|
||||||
|
}, target) YIELD rel
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_displayname_and_labels_query = """
|
||||||
|
MATCH (n)
|
||||||
|
SET n.displayName = n.id
|
||||||
|
WITH n
|
||||||
|
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a Neo4j driver
|
||||||
|
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute queries in batches
|
||||||
|
with driver.session() as session:
|
||||||
|
# Insert nodes in batches
|
||||||
|
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
|
||||||
|
|
||||||
|
# Insert edges in batches
|
||||||
|
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
|
||||||
|
|
||||||
|
# Set displayName and labels
|
||||||
|
session.run(set_displayname_and_labels_query)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
driver.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Evaluation
|
## Evaluation
|
||||||
### Dataset
|
### Dataset
|
||||||
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
|
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
|
||||||
@@ -484,8 +620,9 @@ def extract_queries(file_path):
|
|||||||
.
|
.
|
||||||
├── examples
|
├── examples
|
||||||
│ ├── batch_eval.py
|
│ ├── batch_eval.py
|
||||||
|
│ ├── graph_visual_with_html.py
|
||||||
|
│ ├── graph_visual_with_neo4j.py
|
||||||
│ ├── generate_query.py
|
│ ├── generate_query.py
|
||||||
│ ├── graph_visual.py
|
|
||||||
│ ├── lightrag_azure_openai_demo.py
|
│ ├── lightrag_azure_openai_demo.py
|
||||||
│ ├── lightrag_bedrock_demo.py
|
│ ├── lightrag_bedrock_demo.py
|
||||||
│ ├── lightrag_hf_demo.py
|
│ ├── lightrag_hf_demo.py
|
||||||
|
118
examples/graph_visual_with_neo4j.py
Normal file
118
examples/graph_visual_with_neo4j.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from lightrag.utils import xml_to_json
|
||||||
|
from neo4j import GraphDatabase
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
BATCH_SIZE_NODES = 500
|
||||||
|
BATCH_SIZE_EDGES = 100
|
||||||
|
|
||||||
|
# Neo4j connection credentials
|
||||||
|
NEO4J_URI = "bolt://localhost:7687"
|
||||||
|
NEO4J_USERNAME = "neo4j"
|
||||||
|
NEO4J_PASSWORD = "your_password"
|
||||||
|
|
||||||
|
def convert_xml_to_json(xml_path, output_path):
|
||||||
|
"""Converts XML file to JSON and saves the output."""
|
||||||
|
if not os.path.exists(xml_path):
|
||||||
|
print(f"Error: File not found - {xml_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_data = xml_to_json(xml_path)
|
||||||
|
if json_data:
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(json_data, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"JSON file created: {output_path}")
|
||||||
|
return json_data
|
||||||
|
else:
|
||||||
|
print("Failed to create JSON data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_in_batches(tx, query, data, batch_size):
|
||||||
|
"""Process data in batches and execute the given query."""
|
||||||
|
for i in range(0, len(data), batch_size):
|
||||||
|
batch = data[i:i + batch_size]
|
||||||
|
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Paths
|
||||||
|
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
|
||||||
|
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
|
||||||
|
|
||||||
|
# Convert XML to JSON
|
||||||
|
json_data = convert_xml_to_json(xml_file, json_file)
|
||||||
|
if json_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load nodes and edges
|
||||||
|
nodes = json_data.get('nodes', [])
|
||||||
|
edges = json_data.get('edges', [])
|
||||||
|
|
||||||
|
# Neo4j queries
|
||||||
|
create_nodes_query = """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (e:Entity {id: node.id})
|
||||||
|
SET e.entity_type = node.entity_type,
|
||||||
|
e.description = node.description,
|
||||||
|
e.source_id = node.source_id,
|
||||||
|
e.displayName = node.id
|
||||||
|
REMOVE e:Entity
|
||||||
|
WITH e, node
|
||||||
|
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
create_edges_query = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (source {id: edge.source})
|
||||||
|
MATCH (target {id: edge.target})
|
||||||
|
WITH source, target, edge,
|
||||||
|
CASE
|
||||||
|
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
|
||||||
|
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
|
||||||
|
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
|
||||||
|
WHEN edge.keywords CONTAINS 'located' THEN 'located'
|
||||||
|
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
|
||||||
|
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
|
||||||
|
END AS relType
|
||||||
|
CALL apoc.create.relationship(source, relType, {
|
||||||
|
weight: edge.weight,
|
||||||
|
description: edge.description,
|
||||||
|
keywords: edge.keywords,
|
||||||
|
source_id: edge.source_id
|
||||||
|
}, target) YIELD rel
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_displayname_and_labels_query = """
|
||||||
|
MATCH (n)
|
||||||
|
SET n.displayName = n.id
|
||||||
|
WITH n
|
||||||
|
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a Neo4j driver
|
||||||
|
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute queries in batches
|
||||||
|
with driver.session() as session:
|
||||||
|
# Insert nodes in batches
|
||||||
|
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
|
||||||
|
|
||||||
|
# Insert edges in batches
|
||||||
|
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
|
||||||
|
|
||||||
|
# Set displayName and labels
|
||||||
|
session.run(set_displayname_and_labels_query)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
driver.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@@ -8,6 +8,7 @@ from dataclasses import dataclass
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@@ -183,3 +184,51 @@ def list_of_list_to_csv(data: list[list]):
|
|||||||
def save_data_to_file(data, file_name):
|
def save_data_to_file(data, file_name):
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
def xml_to_json(xml_file):
|
||||||
|
try:
|
||||||
|
tree = ET.parse(xml_file)
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
# Print the root element's tag and attributes to confirm the file has been correctly loaded
|
||||||
|
print(f"Root element: {root.tag}")
|
||||||
|
print(f"Root attributes: {root.attrib}")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"nodes": [],
|
||||||
|
"edges": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use namespace
|
||||||
|
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
|
||||||
|
|
||||||
|
for node in root.findall('.//node', namespace):
|
||||||
|
node_data = {
|
||||||
|
"id": node.get('id').strip('"'),
|
||||||
|
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
|
||||||
|
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
|
||||||
|
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
|
||||||
|
}
|
||||||
|
data["nodes"].append(node_data)
|
||||||
|
|
||||||
|
for edge in root.findall('.//edge', namespace):
|
||||||
|
edge_data = {
|
||||||
|
"source": edge.get('source').strip('"'),
|
||||||
|
"target": edge.get('target').strip('"'),
|
||||||
|
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
|
||||||
|
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
|
||||||
|
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
|
||||||
|
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
|
||||||
|
}
|
||||||
|
data["edges"].append(edge_data)
|
||||||
|
|
||||||
|
# Print the number of nodes and edges found
|
||||||
|
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
|
||||||
|
|
||||||
|
return data
|
||||||
|
except ET.ParseError as e:
|
||||||
|
print(f"Error parsing XML file: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
return None
|
||||||
|
Reference in New Issue
Block a user