updated clean of what implemented on DocStatusStorage
This commit is contained in:
@@ -92,22 +92,20 @@ class StorageNameSpace:
|
|||||||
class BaseVectorStorage(StorageNameSpace):
|
class BaseVectorStorage(StorageNameSpace):
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
meta_fields: set[str] = field(default_factory=set)
|
meta_fields: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||||
|
"""Query the vector storage and retrieve top_k results."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
"""Use 'content' field from value for embedding, use key as id.
|
"""Insert or update vectors in the storage."""
|
||||||
If embedding_func is None, use 'embedding' field from value
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def delete_entity(self, entity_name: str) -> None:
|
async def delete_entity(self, entity_name: str) -> None:
|
||||||
"""Delete a single entity by its name"""
|
"""Delete a single entity by its name."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||||
"""Delete relations for a given entity by scanning metadata"""
|
"""Delete relations for a given entity."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace):
|
|||||||
embedding_func: EmbeddingFunc | None = None
|
embedding_func: EmbeddingFunc | None = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
|
"""Get value by id"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
|
"""Get values by ids"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
@@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
|
"""Upsert data"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
|
"""Drop the storage"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace):
|
|||||||
"""Check if a node exists in the graph."""
|
"""Check if a node exists in the graph."""
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Check if an edge exists in the graph."""
|
"""Check if an edge exists in the graph."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Get the degree of a node."""
|
"""Get the degree of a node."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Get the degree of an edge."""
|
"""Get the degree of an edge."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Get a node by its id."""
|
"""Get a node by its id."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Get an edge by its source and target node ids."""
|
"""Get an edge by its source and target node ids."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
"""Get all edges connected to a node."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
"""Get all edges connected to a node."""
|
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Upsert a node into the graph."""
|
"""Upsert a node into the graph."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Upsert an edge into the graph."""
|
"""Upsert an edge into the graph."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Delete a node from the graph."""
|
"""Delete a node from the graph."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
"""Embed nodes using an algorithm."""
|
"""Embed nodes using an algorithm."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def embed_nodes(
|
async def embed_nodes(
|
||||||
self, algorithm: str
|
self, algorithm: str
|
||||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
|
"""Get all labels in the graph."""
|
||||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||||
|
|
||||||
"""Get all labels in the graph."""
|
|
||||||
|
|
||||||
async def get_all_labels(self) -> list[str]:
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
"""Get a knowledge graph of a node."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
"""Get a knowledge graph of a node."""
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||||
async def get_knowledge_graph(
|
|
||||||
self, node_label: str, max_depth: int = 5
|
|
||||||
) -> KnowledgeGraph:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@@ -5,7 +5,8 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||||
|
import numpy as np
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("psycopg-pool"):
|
if not pm.is_installed("psycopg-pool"):
|
||||||
@@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"):
|
|||||||
pm.install("asyncpg")
|
pm.install("asyncpg")
|
||||||
|
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
import psycopg
|
import psycopg
|
||||||
from psycopg.rows import namedtuple_row
|
from psycopg.rows import namedtuple_row
|
||||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||||
@@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return single_result["edge_exists"]
|
return single_result["edge_exists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:`{label}`) RETURN n
|
MATCH (n:`{label}`) RETURN n
|
||||||
@@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""
|
|
||||||
Find all edges between nodes of two given labels
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_node_label (str): Label of the source nodes
|
|
||||||
target_node_label (str): Label of the target nodes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of all relationships/edges found
|
|
||||||
"""
|
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
|
|
||||||
@@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""
|
"""
|
||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
:return: List of dictionaries containing edge information
|
:return: List of dictionaries containing edge information
|
||||||
@@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((AGEQueryException,)),
|
retry=retry_if_exception_type((AGEQueryException,)),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert a node in the AGE database.
|
Upsert a node in the AGE database.
|
||||||
|
|
||||||
@@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
retry=retry_if_exception_type((AGEQueryException,)),
|
retry=retry_if_exception_type((AGEQueryException,)),
|
||||||
)
|
)
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
Upsert an edge and its properties between two nodes identified by their labels.
|
||||||
|
|
||||||
@@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage):
|
|||||||
yield connection
|
yield connection
|
||||||
finally:
|
finally:
|
||||||
await self._driver.putconn(connection)
|
await self._driver.putconn(connection)
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
@@ -3,7 +3,9 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from gremlin_python.driver import client, serializer
|
from gremlin_python.driver import client, serializer
|
||||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||||
@@ -15,6 +17,7 @@ from tenacity import (
|
|||||||
wait_exponential,
|
wait_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
|
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
@@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
return result[0]["has_edge"]
|
return result[0]["has_edge"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
entity_name = GremlinStorage._fix_name(node_id)
|
entity_name = GremlinStorage._fix_name(node_id)
|
||||||
query = f"""g
|
query = f"""g
|
||||||
.V().has('graph', {self.graph_name})
|
.V().has('graph', {self.graph_name})
|
||||||
@@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""
|
|
||||||
Find all edges between nodes of two given names
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_node_id (str): Name of the source nodes
|
|
||||||
target_node_id (str): Name of the target nodes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict|None: Dict of found edge properties, or None if not found
|
|
||||||
"""
|
|
||||||
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||||
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||||
query = f"""g
|
query = f"""g
|
||||||
@@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return edge_properties
|
return edge_properties
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""
|
|
||||||
Retrieves all edges (relationships) for a particular node identified by its name.
|
|
||||||
:return: List of tuples containing edge sources and targets
|
|
||||||
"""
|
|
||||||
node_name = GremlinStorage._fix_name(source_node_id)
|
node_name = GremlinStorage._fix_name(source_node_id)
|
||||||
query = f"""g
|
query = f"""g
|
||||||
.E()
|
.E()
|
||||||
@@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((GremlinServerError,)),
|
retry=retry_if_exception_type((GremlinServerError,)),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert a node in the Gremlin graph.
|
Upsert a node in the Gremlin graph.
|
||||||
|
|
||||||
@@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
retry=retry_if_exception_type((GremlinServerError,)),
|
retry=retry_if_exception_type((GremlinServerError,)),
|
||||||
)
|
)
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their names.
|
Upsert an edge and its properties between two nodes identified by their names.
|
||||||
|
|
||||||
@@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def _node2vec_embed(self):
|
async def _node2vec_embed(self):
|
||||||
print("Implemented but never called.")
|
print("Implemented but never called.")
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def embed_nodes(
|
||||||
|
self, algorithm: str
|
||||||
|
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
@@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"):
|
|||||||
if not pm.is_installed("motor"):
|
if not pm.is_installed("motor"):
|
||||||
pm.install("motor")
|
pm.install("motor")
|
||||||
|
|
||||||
from typing import Any, List, Tuple, Union
|
from typing import Any, List, Union
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from pymongo.operations import SearchIndexModel
|
from pymongo.operations import SearchIndexModel
|
||||||
@@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""
|
"""
|
||||||
Return the full node document (including "edges"), or None if missing.
|
Return the full node document (including "edges"), or None if missing.
|
||||||
"""
|
"""
|
||||||
@@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""
|
|
||||||
Return the first edge dict from source_node_id to target_node_id if it exists.
|
|
||||||
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
|
|
||||||
"""
|
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{"$match": {"_id": source_node_id}},
|
{"$match": {"_id": source_node_id}},
|
||||||
{
|
{
|
||||||
@@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
return e
|
return e
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
self, source_node_id: str
|
|
||||||
) -> Union[List[Tuple[str, str]], None]:
|
|
||||||
"""
|
"""
|
||||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
||||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||||
@@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert or update a node document. If new, create an empty edges array.
|
Insert or update a node document. If new, create an empty edges array.
|
||||||
"""
|
"""
|
||||||
@@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
||||||
If an edge with the same target exists, we remove it and re-insert with updated data.
|
If an edge with the same target exists, we remove it and re-insert with updated data.
|
||||||
@@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
1) Remove node's doc entirely.
|
1) Remove node's doc entirely.
|
||||||
2) Remove inbound edges from any doc that references node_id.
|
2) Remove inbound edges from any doc that references node_id.
|
||||||
@@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
"""
|
"""
|
||||||
Placeholder for demonstration, raises NotImplementedError.
|
Placeholder for demonstration, raises NotImplementedError.
|
||||||
"""
|
"""
|
||||||
@@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
labels.append(doc["_id"])
|
labels.append(doc["_id"])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
self, node_label: str, max_depth: int = 5
|
|
||||||
) -> KnowledgeGraph:
|
|
||||||
"""
|
"""
|
||||||
Get complete connected subgraph for specified node (including the starting node itself)
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
|
@@ -3,7 +3,8 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union, Tuple, List, Dict
|
from typing import Any, List, Dict
|
||||||
|
import numpy as np
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
@@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return single_result["edgeExists"]
|
return single_result["edgeExists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier.
|
"""Get node by its label identifier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""Find edge between two nodes identified by their labels.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_node_id (str): Label of the source node
|
|
||||||
target_node_id (str): Label of the target node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Edge properties if found, with at least {"weight": 0.0}
|
|
||||||
None: If error occurs
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
@@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
# Return default edge properties on error
|
# Return default edge properties on error
|
||||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
node_label = source_node_id.strip('"')
|
node_label = source_node_id.strip('"')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert a node in the Neo4j database.
|
Upsert a node in the Neo4j database.
|
||||||
|
|
||||||
@@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
Upsert an edge and its properties between two nodes identified by their labels.
|
||||||
|
|
||||||
@@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def _node2vec_embed(self):
|
async def _node2vec_embed(self):
|
||||||
print("Implemented but never called.")
|
print("Implemented but never called.")
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
self, node_label: str, max_depth: int = 5
|
|
||||||
) -> KnowledgeGraph:
|
|
||||||
"""
|
"""
|
||||||
Get complete connected subgraph for specified node (including the starting node itself)
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
@@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
await traverse(label, 0)
|
await traverse(label, 0)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_all_labels(self) -> List[str]:
|
async def get_all_labels(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Get all existing node labels in the database
|
Get all existing node labels in the database
|
||||||
Returns:
|
Returns:
|
||||||
@@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async for record in result:
|
async for record in result:
|
||||||
labels.append(record["label"])
|
labels.append(record["label"])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def embed_nodes(
|
||||||
|
self, algorithm: str
|
||||||
|
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
|
raise NotImplementedError
|
@@ -51,11 +51,12 @@ Usage:
|
|||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union, cast
|
from typing import Any, cast
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
@@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
"node2vec": self._node2vec_embed,
|
"node2vec": self._node2vec_embed,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
@@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
return self._graph.has_edge(source_node_id, target_node_id)
|
return self._graph.has_edge(source_node_id, target_node_id)
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
return self._graph.nodes.get(node_id)
|
return self._graph.nodes.get(node_id)
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
@@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
return self._graph.edges.get((source_node_id, target_node_id))
|
return self._graph.edges.get((source_node_id, target_node_id))
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str):
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
if self._graph.has_node(source_node_id):
|
if self._graph.has_node(source_node_id):
|
||||||
return list(self._graph.edges(source_node_id))
|
return list(self._graph.edges(source_node_id))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
self._graph.add_node(node_id, **node_data)
|
self._graph.add_node(node_id, **node_data)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
|
||||||
Delete a node from the graph based on the specified node_id.
|
|
||||||
|
|
||||||
:param node_id: The node_id to delete
|
|
||||||
"""
|
|
||||||
if self._graph.has_node(node_id):
|
if self._graph.has_node(node_id):
|
||||||
self._graph.remove_node(node_id)
|
self._graph.remove_node(node_id)
|
||||||
logger.info(f"Node {node_id} deleted from the graph.")
|
logger.info(f"Node {node_id} deleted from the graph.")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
if algorithm not in self._node_embed_algorithms:
|
if algorithm not in self._node_embed_algorithms:
|
||||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||||
return await self._node_embed_algorithms[algorithm]()
|
return await self._node_embed_algorithms[algorithm]()
|
||||||
@@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
for source, target in edges:
|
for source, target in edges:
|
||||||
if self._graph.has_edge(source, target):
|
if self._graph.has_edge(source, target):
|
||||||
self._graph.remove_edge(source, target)
|
self._graph.remove_edge(source, target)
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
@@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"):
|
|||||||
pm.install("oracledb")
|
pm.install("oracledb")
|
||||||
|
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
import oracledb
|
import oracledb
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
@@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
#################### insert method ################
|
#################### insert method ################
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""插入或更新节点"""
|
|
||||||
# print("go into upsert node method")
|
|
||||||
entity_name = node_id
|
entity_name = node_id
|
||||||
entity_type = node_data["entity_type"]
|
entity_type = node_data["entity_type"]
|
||||||
description = node_data["description"]
|
description = node_data["description"]
|
||||||
@@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""插入或更新边"""
|
"""插入或更新边"""
|
||||||
# print("go into upsert edge method")
|
# print("go into upsert edge method")
|
||||||
source_name = source_node_id
|
source_name = source_node_id
|
||||||
@@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
"""为节点生成向量"""
|
|
||||||
if algorithm not in self._node_embed_algorithms:
|
if algorithm not in self._node_embed_algorithms:
|
||||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||||
return await self._node_embed_algorithms[algorithm]()
|
return await self._node_embed_algorithms[algorithm]()
|
||||||
@@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||||
return embeddings, nodes_ids
|
return embeddings, nodes_ids
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self) -> None:
|
||||||
"""写入graphhml图文件"""
|
"""写入graphhml图文件"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
||||||
@@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""根据源和目标节点id检查边是否存在"""
|
|
||||||
SQL = SQL_TEMPLATES["has_edge"]
|
SQL = SQL_TEMPLATES["has_edge"]
|
||||||
params = {
|
params = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
@@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""根据节点id获取节点的度"""
|
|
||||||
SQL = SQL_TEMPLATES["node_degree"]
|
SQL = SQL_TEMPLATES["node_degree"]
|
||||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
@@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
# print("Edge degree",degree)
|
# print("Edge degree",degree)
|
||||||
return degree
|
return degree
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""根据节点id获取节点数据"""
|
"""根据节点id获取节点数据"""
|
||||||
SQL = SQL_TEMPLATES["get_node"]
|
SQL = SQL_TEMPLATES["get_node"]
|
||||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
@@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""根据源和目标节点id获取边"""
|
|
||||||
SQL = SQL_TEMPLATES["get_edge"]
|
SQL = SQL_TEMPLATES["get_edge"]
|
||||||
params = {
|
params = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
@@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str):
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""根据节点id获取节点的所有边"""
|
|
||||||
if await self.has_node(source_node_id):
|
if await self.has_node(source_node_id):
|
||||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||||
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||||
@@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
if res:
|
if res:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
N_T = {
|
N_T = {
|
||||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||||
|
@@ -4,11 +4,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
|
|
||||||
if not pm.is_installed("asyncpg"):
|
if not pm.is_installed("asyncpg"):
|
||||||
pm.install("asyncpg")
|
pm.install("asyncpg")
|
||||||
|
|
||||||
@@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return single_result["edge_exists"]
|
return single_result["edge_exists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:Entity {node_id: "%s"})
|
MATCH (n:Entity {node_id: "%s"})
|
||||||
@@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
"""
|
|
||||||
Find all edges between nodes of two given labels
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_node_id (str): Label of the source nodes
|
|
||||||
target_node_id (str): Label of the target nodes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of all relationships/edges found
|
|
||||||
"""
|
|
||||||
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||||
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
||||||
|
|
||||||
@@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""
|
"""
|
||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
:return: List of dictionaries containing edge information
|
:return: List of dictionaries containing edge information
|
||||||
@@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((PGGraphQueryException,)),
|
retry=retry_if_exception_type((PGGraphQueryException,)),
|
||||||
)
|
)
|
||||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
|
||||||
Upsert a node in the AGE database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The unique identifier for the node (used as label)
|
|
||||||
node_data: Dictionary of node properties
|
|
||||||
"""
|
|
||||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||||
properties = node_data
|
properties = node_data
|
||||||
|
|
||||||
@@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
retry=retry_if_exception_type((PGGraphQueryException,)),
|
retry=retry_if_exception_type((PGGraphQueryException,)),
|
||||||
)
|
)
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
Upsert an edge and its properties between two nodes identified by their labels.
|
||||||
|
|
||||||
@@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def _node2vec_embed(self):
|
async def _node2vec_embed(self):
|
||||||
print("Implemented but never called.")
|
print("Implemented but never called.")
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def embed_nodes(
|
||||||
|
self, algorithm: str
|
||||||
|
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
NAMESPACE_TABLE_MAP = {
|
NAMESPACE_TABLE_MAP = {
|
||||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||||
|
@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
|
|||||||
if not pm.is_installed("sqlalchemy"):
|
if not pm.is_installed("sqlalchemy"):
|
||||||
pm.install("sqlalchemy")
|
pm.install("sqlalchemy")
|
||||||
|
|
||||||
|
from lightrag.types import KnowledgeGraph
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
#################### upsert method ################
|
#################### upsert method ################
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
entity_name = node_id
|
entity_name = node_id
|
||||||
entity_type = node_data["entity_type"]
|
entity_type = node_data["entity_type"]
|
||||||
description = node_data["description"]
|
description = node_data["description"]
|
||||||
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
):
|
) -> None:
|
||||||
source_name = source_node_id
|
source_name = source_node_id
|
||||||
target_name = target_node_id
|
target_name = target_node_id
|
||||||
weight = edge_data["weight"]
|
weight = edge_data["weight"]
|
||||||
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
}
|
}
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||||
if algorithm not in self._node_embed_algorithms:
|
if algorithm not in self._node_embed_algorithms:
|
||||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||||
return await self._node_embed_algorithms[algorithm]()
|
return await self._node_embed_algorithms[algorithm]()
|
||||||
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||||
return degree
|
return degree
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
sql = SQL_TEMPLATES["get_node"]
|
sql = SQL_TEMPLATES["get_node"]
|
||||||
param = {"name": node_id, "workspace": self.db.workspace}
|
param = {"name": node_id, "workspace": self.db.workspace}
|
||||||
return await self.db.query(sql, param)
|
return await self.db.query(sql, param)
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> dict[str, str] | None:
|
||||||
sql = SQL_TEMPLATES["get_edge"]
|
sql = SQL_TEMPLATES["get_edge"]
|
||||||
param = {
|
param = {
|
||||||
"source_name": source_node_id,
|
"source_name": source_node_id,
|
||||||
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
}
|
}
|
||||||
return await self.db.query(sql, param)
|
return await self.db.query(sql, param)
|
||||||
|
|
||||||
async def get_node_edges(
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
self, source_node_id: str
|
|
||||||
) -> Union[list[tuple[str, str]], None]:
|
|
||||||
sql = SQL_TEMPLATES["get_node_edges"]
|
sql = SQL_TEMPLATES["get_node_edges"]
|
||||||
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
||||||
res = await self.db.query(sql, param, multirows=True)
|
res = await self.db.query(sql, param, multirows=True)
|
||||||
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def delete_node(self, node_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_all_labels(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
N_T = {
|
N_T = {
|
||||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||||
|
Reference in New Issue
Block a user