diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 35ef85a8..1cd2e7a3 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -1,6 +1,6 @@ import os import json -from lightrag.utils import xml_to_json +import xml.etree.ElementTree as ET from neo4j import GraphDatabase # Constants @@ -14,6 +14,66 @@ NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "your_password" +def xml_to_json(xml_file): + try: + tree = ET.parse(xml_file) + root = tree.getroot() + + # Print the root element's tag and attributes to confirm the file has been correctly loaded + print(f"Root element: {root.tag}") + print(f"Root attributes: {root.attrib}") + + data = {"nodes": [], "edges": []} + + # Use namespace + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} + + for node in root.findall(".//node", namespace): + node_data = { + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d1']", namespace).text.strip('"') + if node.find("./data[@key='d1']", namespace) is not None + else "", + "description": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d3']", namespace).text + if node.find("./data[@key='d3']", namespace) is not None + else "", + } + data["nodes"].append(node_data) + + for edge in root.findall(".//edge", namespace): + edge_data = { + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d5']", namespace).text) + if edge.find("./data[@key='d5']", namespace) is not None + else 0.0, + "description": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d7']", namespace).text + if edge.find("./data[@key='d7']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d8']", namespace).text + if edge.find("./data[@key='d8']", namespace) is not None + else "", + } + data["edges"].append(edge_data) + + # Print the number of nodes and edges found + print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") + + return data + except ET.ParseError as e: + print(f"Error parsing XML file: {e}") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + + def convert_xml_to_json(xml_path, output_path): """Converts XML file to JSON and saves the output.""" if not os.path.exists(xml_path): diff --git a/lightrag/utils.py b/lightrag/utils.py index 7ecb11e3..2e75b9b9 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -13,7 +13,6 @@ from dataclasses import dataclass from functools import wraps from hashlib import md5 from typing import Any, Protocol, Callable, TYPE_CHECKING, List -import xml.etree.ElementTree as ET import numpy as np from lightrag.prompt import PROMPTS from dotenv import load_dotenv @@ -753,71 +752,6 @@ def truncate_list_by_token_size( return list_data -def save_data_to_file(data, file_name): - with open(file_name, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=4) - - -def xml_to_json(xml_file): - try: - tree = ET.parse(xml_file) - root = tree.getroot() - - # Print the root element's tag and attributes to confirm the file has been correctly loaded - print(f"Root element: {root.tag}") - print(f"Root attributes: {root.attrib}") - - data = {"nodes": [], "edges": []} - - # Use namespace - namespace = {"": "http://graphml.graphdrawing.org/xmlns"} - - for node in root.findall(".//node", namespace): - node_data = { - "id": node.get("id").strip('"'), - "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') - if node.find("./data[@key='d0']", namespace) is not None - else "", - "description": node.find("./data[@key='d1']", namespace).text - if node.find("./data[@key='d1']", namespace) is not None - else "", - "source_id": node.find("./data[@key='d2']", namespace).text - if node.find("./data[@key='d2']", namespace) is not None - else "", - } - data["nodes"].append(node_data) - - for edge in root.findall(".//edge", namespace): - edge_data = { - "source": edge.get("source").strip('"'), - "target": edge.get("target").strip('"'), - "weight": float(edge.find("./data[@key='d3']", namespace).text) - if edge.find("./data[@key='d3']", namespace) is not None - else 0.0, - "description": edge.find("./data[@key='d4']", namespace).text - if edge.find("./data[@key='d4']", namespace) is not None - else "", - "keywords": edge.find("./data[@key='d5']", namespace).text - if edge.find("./data[@key='d5']", namespace) is not None - else "", - "source_id": edge.find("./data[@key='d6']", namespace).text - if edge.find("./data[@key='d6']", namespace) is not None - else "", - } - data["edges"].append(edge_data) - - # Print the number of nodes and edges found - print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") - - return data - except ET.ParseError as e: - print(f"Error parsing XML file: {e}") - return None - except Exception as e: - print(f"An error occurred: {e}") - return None - - def process_combine_contexts(*context_lists): """ Combine multiple context lists and remove duplicate content