Revert "Cleanup of code"

This commit is contained in:
Yannick Stephan
2025-02-20 15:09:43 +01:00
committed by GitHub
parent c431cd584a
commit 678e0f9aea
12 changed files with 84 additions and 7 deletions

View File

@@ -1682,6 +1682,11 @@ def create_app(args):
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# query all graph labels
@app.get("/graph/label/list")
async def get_graph_labels():
return await rag.get_graph_labels()
# query all graph # query all graph
@app.get("/graphs") @app.get("/graphs")
async def get_knowledge_graph(label: str): async def get_knowledge_graph(label: str):

View File

@@ -198,6 +198,10 @@ class BaseGraphStorage(StorageNameSpace, ABC):
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Get all labels in the graph.""" """Get all labels in the graph."""
@abstractmethod
async def get_all_labels(self) -> list[str]:
"""Get a knowledge graph of a node."""
@abstractmethod @abstractmethod
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5

View File

@@ -60,6 +60,10 @@ class AGEQueryException(Exception):
@final @final
@dataclass @dataclass
class AGEStorage(BaseGraphStorage): class AGEStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
@@ -616,6 +620,9 @@ class AGEStorage(BaseGraphStorage):
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -403,6 +403,9 @@ class GremlinStorage(BaseGraphStorage):
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -601,6 +601,24 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def get_all_labels(self) -> list[str]:
"""
Get all existing node _id in the database
Returns:
[id1, id2, ...] # Alphabetically sorted id list
"""
# Use MongoDB's distinct and aggregation to get all unique labels
pipeline = [
{"$group": {"_id": "$_id"}}, # Group by _id
{"$sort": {"_id": 1}}, # Sort alphabetically
]
cursor = self.collection.aggregate(pipeline)
labels = []
async for doc in cursor:
labels.append(doc["_id"])
return labels
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -628,6 +628,31 @@ class Neo4JStorage(BaseGraphStorage):
await traverse(label, 0) await traverse(label, 0)
return result return result
async def get_all_labels(self) -> list[str]:
"""
Get all existing node labels in the database
Returns:
["Person", "Company", ...] # Alphabetically sorted label list
"""
async with self._driver.session(database=self._DATABASE) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label"
# Method 2: Query compatible with older versions
query = """
MATCH (n)
WITH DISTINCT labels(n) AS node_labels
UNWIND node_labels AS label
RETURN DISTINCT label
ORDER BY label
"""
result = await session.run(query)
labels = []
async for record in result:
labels.append(record["label"])
return labels
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError

View File

@@ -168,6 +168,9 @@ class NetworkXStorage(BaseGraphStorage):
if self._graph.has_edge(source, target): if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target) self._graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -670,6 +670,9 @@ class OracleGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -178,10 +178,12 @@ class PostgreSQLDB:
asyncpg.exceptions.UniqueViolationError, asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError, asyncpg.exceptions.DuplicateTableError,
) as e: ) as e:
if not upsert: if upsert:
logger.error(f"PostgreSQL, upsert error: {e}") print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database, sql:{sql}, data:{data}, error:{e}") logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise raise
@@ -1085,6 +1087,9 @@ class PGGraphStorage(BaseGraphStorage):
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -560,6 +560,9 @@ class TiDBGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -458,6 +458,10 @@ class LightRAG:
self._storages_status = StoragesStatus.FINALIZED self._storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages") logger.debug("Finalized Storages")
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
async def get_knowledge_graph( async def get_knowledge_graph(
self, nodel_label: str, max_depth: int self, nodel_label: str, max_depth: int
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -1,6 +1,6 @@
from typing import Optional, Tuple, Dict, List from typing import Optional, Tuple, Dict, List
import numpy as np import numpy as np
import networkx as nx
import pipmaster as pm import pipmaster as pm
# Added automatic libraries install using pipmaster # Added automatic libraries install using pipmaster
@@ -12,10 +12,7 @@ if not pm.is_installed("pyglm"):
pm.install("pyglm") pm.install("pyglm")
if not pm.is_installed("python-louvain"): if not pm.is_installed("python-louvain"):
pm.install("python-louvain") pm.install("python-louvain")
if not pm.is_installed("networkx"):
pm.install("networkx")
import networkx as nx
import moderngl import moderngl
from imgui_bundle import imgui, immapp, hello_imgui from imgui_bundle import imgui, immapp, hello_imgui
import community import community