Updated: cleaned up what is implemented on DocStatusStorage
This commit is contained in:
@@ -92,22 +92,20 @@ class StorageNameSpace:
|
||||
class BaseVectorStorage(StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc
|
||||
meta_fields: set[str] = field(default_factory=set)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""Query the vector storage and retrieve top_k results."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Use 'content' field from value for embedding, use key as id.
|
||||
If embedding_func is None, use 'embedding' field from value
|
||||
"""
|
||||
"""Insert or update vectors in the storage."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
"""Delete a single entity by its name."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
"""Delete relations for a given entity."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc | None = None
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get value by id"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get values by ids"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
@@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace):
|
||||
raise NotImplementedError
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Upsert data"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace):
|
||||
"""Check if a node exists in the graph."""
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node exists in the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Check if an edge exists in the graph."""
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""Check if an edge exists in the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get the degree of a node."""
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""Get the degree of a node."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get the degree of an edge."""
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""Get the degree of an edge."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get a node by its id."""
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get a node by its id."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get an edge by its source and target node ids."""
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
"""Get an edge by its source and target node ids."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get all edges connected to a node."""
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""Get all edges connected to a node."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Upsert a node into the graph."""
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""Upsert a node into the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Upsert an edge into the graph."""
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""Upsert an edge into the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Delete a node from the graph."""
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node from the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Embed nodes using an algorithm."""
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
"""Embed nodes using an algorithm."""
|
||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||
|
||||
"""Get all labels in the graph."""
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""Get all labels in the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Get a knowledge graph of a node."""
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@@ -5,7 +5,8 @@ import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("psycopg-pool"):
|
||||
@@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
@@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage):
|
||||
)
|
||||
return single_result["edge_exists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = """
|
||||
MATCH (n:`{label}`) RETURN n
|
||||
@@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
|
||||
Args:
|
||||
source_node_label (str): Label of the source nodes
|
||||
target_node_label (str): Label of the target nodes
|
||||
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
) -> dict[str, str] | None:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
@@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage):
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
:return: List of dictionaries containing edge information
|
||||
@@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((AGEQueryException,)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Upsert a node in the AGE database.
|
||||
|
||||
@@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage):
|
||||
retry=retry_if_exception_type((AGEQueryException,)),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
|
||||
@@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage):
|
||||
yield connection
|
||||
finally:
|
||||
await self._driver.putconn(connection)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
@@ -3,7 +3,9 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
@@ -15,6 +17,7 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.utils import logger
|
||||
|
||||
from ..base import BaseGraphStorage
|
||||
@@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage):
|
||||
|
||||
return result[0]["has_edge"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
entity_name = GremlinStorage._fix_name(node_id)
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
@@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Find all edges between nodes of two given names
|
||||
|
||||
Args:
|
||||
source_node_id (str): Name of the source nodes
|
||||
target_node_id (str): Name of the target nodes
|
||||
|
||||
Returns:
|
||||
dict|None: Dict of found edge properties, or None if not found
|
||||
"""
|
||||
) -> dict[str, str] | None:
|
||||
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||
query = f"""g
|
||||
@@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage):
|
||||
)
|
||||
return edge_properties
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Retrieves all edges (relationships) for a particular node identified by its name.
|
||||
:return: List of tuples containing edge sources and targets
|
||||
"""
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
node_name = GremlinStorage._fix_name(source_node_id)
|
||||
query = f"""g
|
||||
.E()
|
||||
@@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((GremlinServerError,)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Upsert a node in the Gremlin graph.
|
||||
|
||||
@@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage):
|
||||
retry=retry_if_exception_type((GremlinServerError,)),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their names.
|
||||
|
||||
@@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage):
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
@@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"):
|
||||
if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
from typing import Any, List, Union
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import SearchIndexModel
|
||||
@@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""
|
||||
Return the full node document (including "edges"), or None if missing.
|
||||
"""
|
||||
@@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Return the first edge dict from source_node_id to target_node_id if it exists.
|
||||
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
|
||||
"""
|
||||
) -> dict[str, str] | None:
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
{
|
||||
@@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
return e
|
||||
return None
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[List[Tuple[str, str]], None]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||
@@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Insert or update a node document. If new, create an empty edges array.
|
||||
"""
|
||||
@@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict
|
||||
):
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
||||
If an edge with the same target exists, we remove it and re-insert with updated data.
|
||||
@@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""
|
||||
1) Remove node's doc entirely.
|
||||
2) Remove inbound edges from any doc that references node_id.
|
||||
@@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
"""
|
||||
Placeholder for demonstration, raises NotImplementedError.
|
||||
"""
|
||||
@@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
labels.append(doc["_id"])
|
||||
return labels
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
||||
|
@@ -3,7 +3,8 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
from typing import Any, List, Dict
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
@@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
return single_result["edgeExists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier.
|
||||
|
||||
Args:
|
||||
@@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""Find edge between two nodes identified by their labels.
|
||||
) -> dict[str, str] | None:
|
||||
|
||||
Args:
|
||||
source_node_id (str): Label of the source node
|
||||
target_node_id (str): Label of the target node
|
||||
|
||||
Returns:
|
||||
dict: Edge properties if found, with at least {"weight": 0.0}
|
||||
None: If error occurs
|
||||
"""
|
||||
try:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
@@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Return default edge properties on error
|
||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
node_label = source_node_id.strip('"')
|
||||
|
||||
"""
|
||||
@@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Upsert a node in the Neo4j database.
|
||||
|
||||
@@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
|
||||
@@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
||||
@@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
await traverse(label, 0)
|
||||
return result
|
||||
|
||||
async def get_all_labels(self) -> List[str]:
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all existing node labels in the database
|
||||
Returns:
|
||||
@@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async for record in result:
|
||||
labels.append(record["label"])
|
||||
return labels
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
@@ -51,11 +51,12 @@ Usage:
|
||||
import html
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
)
|
||||
@@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
@@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
return self._graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
return self._graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
@@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
) -> dict[str, str] | None:
|
||||
return self._graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
) -> None:
|
||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
Delete a node from the graph based on the specified node_id.
|
||||
|
||||
:param node_id: The node_id to delete
|
||||
"""
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
if self._graph.has_node(node_id):
|
||||
self._graph.remove_node(node_id)
|
||||
logger.info(f"Node {node_id} deleted from the graph.")
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
@@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
for source, target in edges:
|
||||
if self._graph.has_edge(source, target):
|
||||
self._graph.remove_edge(source, target)
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
@@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"):
|
||||
pm.install("oracledb")
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
import oracledb
|
||||
|
||||
from ..base import (
|
||||
@@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
#################### insert method ################
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
"""插入或更新节点"""
|
||||
# print("go into upsert node method")
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
@@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
) -> None:
|
||||
"""插入或更新边"""
|
||||
# print("go into upsert edge method")
|
||||
source_name = source_node_id
|
||||
@@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
await self.db.execute(merge_sql, data)
|
||||
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
"""为节点生成向量"""
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
@@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
"""写入graphml图文件"""
|
||||
logger.info(
|
||||
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
||||
@@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
return False
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""根据源和目标节点id检查边是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_edge"]
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
@@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
return False
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""根据节点id获取节点的度"""
|
||||
SQL = SQL_TEMPLATES["node_degree"]
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(SQL)
|
||||
@@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
# print("Edge degree",degree)
|
||||
return degree
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""根据节点id获取节点数据"""
|
||||
SQL = SQL_TEMPLATES["get_node"]
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
@@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""根据源和目标节点id获取边"""
|
||||
) -> dict[str, str] | None:
|
||||
SQL = SQL_TEMPLATES["get_edge"]
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
@@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
"""根据节点id获取节点的所有边"""
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
if await self.has_node(source_node_id):
|
||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||
@@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
if res:
|
||||
return res
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
|
||||
N_T = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
@@ -4,11 +4,13 @@ import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
@@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
)
|
||||
return single_result["edge_exists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
@@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
|
||||
Args:
|
||||
source_node_id (str): Label of the source nodes
|
||||
target_node_id (str): Label of the target nodes
|
||||
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
) -> dict[str, str] | None:
|
||||
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
||||
|
||||
@@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
:return: List of dictionaries containing edge information
|
||||
@@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((PGGraphQueryException,)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
"""
|
||||
Upsert a node in the AGE database.
|
||||
|
||||
Args:
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
node_data: Dictionary of node properties
|
||||
"""
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
properties = node_data
|
||||
|
||||
@@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
retry=retry_if_exception_type((PGGraphQueryException,)),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
|
||||
@@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
|
||||
NAMESPACE_TABLE_MAP = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from sqlalchemy import create_engine, text
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
#################### upsert method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
) -> None:
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
}
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||
return degree
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
sql = SQL_TEMPLATES["get_node"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
) -> dict[str, str] | None:
|
||||
sql = SQL_TEMPLATES["get_edge"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[list[tuple[str, str]], None]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
sql = SQL_TEMPLATES["get_node_edges"]
|
||||
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
||||
res = await self.db.query(sql, param, multirows=True)
|
||||
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
return []
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
|
||||
N_T = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
Reference in New Issue
Block a user