diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 35ef85a8..637cb36d 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -1,6 +1,6 @@ import os import json -from lightrag.utils import xml_to_json +import xml.etree.ElementTree as ET from neo4j import GraphDatabase # Constants @@ -13,6 +13,66 @@ NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "your_password" +def xml_to_json(xml_file): + try: + tree = ET.parse(xml_file) + root = tree.getroot() + + # Print the root element's tag and attributes to confirm the file has been correctly loaded + print(f"Root element: {root.tag}") + print(f"Root attributes: {root.attrib}") + + data = {"nodes": [], "edges": []} + + # Use namespace + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} + + for node in root.findall(".//node", namespace): + node_data = { + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d1']", namespace).text.strip('"') + if node.find("./data[@key='d1']", namespace) is not None + else "", + "description": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d3']", namespace).text + if node.find("./data[@key='d3']", namespace) is not None + else "", + } + data["nodes"].append(node_data) + + for edge in root.findall(".//edge", namespace): + edge_data = { + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d5']", namespace).text) + if edge.find("./data[@key='d5']", namespace) is not None + else 0.0, + "description": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d7']", namespace).text + if edge.find("./data[@key='d7']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d8']", namespace).text + if edge.find("./data[@key='d8']", namespace) is not None + else "", + } + data["edges"].append(edge_data) + + + # Print the number of nodes and edges found + print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") + + return data + except ET.ParseError as e: + print(f"Error parsing XML file: {e}") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + def convert_xml_to_json(xml_path, output_path): """Converts XML file to JSON and saves the output.""" diff --git a/lightrag/utils.py b/lightrag/utils.py index 7ecb11e3..4a3378db 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -13,7 +13,6 @@ from dataclasses import dataclass from functools import wraps from hashlib import md5 from typing import Any, Protocol, Callable, TYPE_CHECKING, List -import xml.etree.ElementTree as ET import numpy as np from lightrag.prompt import PROMPTS from dotenv import load_dotenv