Add visualization methods

This commit is contained in:
LarFii
2024-10-20 23:08:26 +08:00
parent 6b9430221e
commit bcd81e3a50
5 changed files with 308 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union
import xml.etree.ElementTree as ET
import numpy as np
import tiktoken
@@ -183,3 +184,51 @@ def list_of_list_to_csv(data: list[list]):
def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Print the root element's tag and attributes to confirm the file has been correctly loaded
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {
"nodes": [],
"edges": []
}
# Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
for node in root.findall('.//node', namespace):
node_data = {
"id": node.get('id').strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
}
data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace):
edge_data = {
"source": edge.get('source').strip('"'),
"target": edge.get('target').strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
}
data["edges"].append(edge_data)
# Print the number of nodes and edges found
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
except ET.ParseError as e:
print(f"Error parsing XML file: {e}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None