added final, required methods and cleaned import

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -92,6 +92,7 @@ class StorageNameSpace(ABC):
@dataclass @dataclass
class BaseVectorStorage(StorageNameSpace, ABC): class BaseVectorStorage(StorageNameSpace, ABC):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
cosine_better_than_threshold: float
meta_fields: set[str] = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
@abstractmethod @abstractmethod

View File

@@ -5,21 +5,11 @@ import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
from lightrag.types import KnowledgeGraph from lightrag.types import KnowledgeGraph
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from tenacity import ( from tenacity import (
retry, retry,
retry_if_exception_type, retry_if_exception_type,
@@ -37,6 +27,16 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
except ImportError as e:
raise ImportError(
"psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed."
) from e
class AGEQueryException(Exception): class AGEQueryException(Exception):
"""Exception for the AGE queries.""" """Exception for the AGE queries."""
@@ -55,6 +55,7 @@ class AGEQueryException(Exception):
return self.details return self.details
@final
@dataclass @dataclass
class AGEStorage(BaseGraphStorage): class AGEStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
@staticmethod @staticmethod
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
""" """
@@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError
async def index_done_callback(self) -> None:
pass

View File

@@ -1,19 +1,25 @@
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, final
import numpy as np import numpy as np
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
from lightrag.utils import logger from lightrag.utils import logger
try:
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
except ImportError as e:
raise ImportError(
"chromadb library is not installed. Please install it to proceed."
) from e
@final
@dataclass @dataclass
class ChromaVectorDBStorage(BaseVectorStorage): class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation.""" """ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
try: try:
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -1,8 +1,8 @@
import os import os
import time import time
import asyncio import asyncio
from typing import Any from typing import Any, final
import faiss
import json import json
import numpy as np import numpy as np
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
@@ -16,7 +16,15 @@ from lightrag.base import (
BaseVectorStorage, BaseVectorStorage,
) )
try:
import faiss
except ImportError as e:
raise ImportError(
"faiss library is not installed. Please install it to proceed."
) from e
@final
@dataclass @dataclass
class FaissVectorDBStorage(BaseVectorStorage): class FaissVectorDBStorage(BaseVectorStorage):
""" """
@@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
""" """
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Grab config values if available # Grab config values if available
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -3,13 +3,11 @@ import inspect
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List from typing import Any, Dict, List, final
import numpy as np import numpy as np
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
from tenacity import ( from tenacity import (
retry, retry,
retry_if_exception_type, retry_if_exception_type,
@@ -22,7 +20,17 @@ from lightrag.utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
try:
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
except ImportError as e:
raise ImportError(
"gremlin library is not installed. Please install it to proceed."
) from e
@final
@dataclass @dataclass
class GremlinStorage(BaseGraphStorage): class GremlinStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage):
if self._driver: if self._driver:
self._driver.close() self._driver.close()
async def index_done_callback(self): async def index_done_callback(self) -> None:
print("KG successfully indexed.") pass
@staticmethod @staticmethod
def _to_value_map(value: Any) -> str: def _to_value_map(value: Any) -> str:

View File

@@ -1,56 +1,6 @@
"""
JsonDocStatus Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
from dataclasses import dataclass from dataclasses import dataclass
import os import os
from typing import Any, Union from typing import Any, Union, final
from lightrag.base import ( from lightrag.base import (
DocProcessingStatus, DocProcessingStatus,
@@ -64,6 +14,7 @@ from lightrag.utils import (
) )
@final
@dataclass @dataclass
class JsonDocStatusStorage(DocStatusStorage): class JsonDocStatusStorage(DocStatusStorage):
"""JSON implementation of document status storage""" """JSON implementation of document status storage"""
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
self._data: dict[str, Any] = load_json(self._file_name) or {} self._data: dict[str, Any] = load_json(self._file_name) or {}
logger.info(f"Loaded document status storage with {len(self._data)} records") logger.info(f"Loaded document status storage with {len(self._data)} records")
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)""" """Return keys that should be processed (not in storage or not successfully processed)"""
return set(data) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = [] result: list[dict[str, Any]] = []
@@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage):
return counts return counts
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
@@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage):
} }
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
@@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage):
} }
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
@@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage):
} }
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING if v["status"] == DocStatus.PROCESSING
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
"""Save data to file after indexing"""
write_json(self._data, self._file_name) write_json(self._data, self._file_name)
async def upsert(self, data: dict[str, Any]) -> None: async def upsert(self, data: dict[str, Any]) -> None:
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
self._data.update(data) self._data.update(data)
await self.index_done_callback() await self.index_done_callback()
@@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage):
return self._data.get(id) return self._data.get(id)
async def delete(self, doc_ids: list[str]): async def delete(self, doc_ids: list[str]):
"""Delete document status by IDs"""
for doc_id in doc_ids: for doc_id in doc_ids:
self._data.pop(doc_id, None) self._data.pop(doc_id, None)
await self.index_done_callback() await self.index_done_callback()
async def drop(self) -> None:
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
raise NotImplementedError

View File

@@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage):
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
logger.info(f"Load KV {self.namespace} with {len(self._data)} data") logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def index_done_callback(self): async def index_done_callback(self) -> None:
write_json(self._data, self._file_name) write_json(self._data, self._file_name)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:

View File

@@ -1,27 +1,29 @@
import asyncio import asyncio
import os import os
from typing import Any from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from lightrag.utils import logger from lightrag.utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
import pipmaster as pm
import configparser import configparser
if not pm.is_installed("pymilvus"): try:
pm.install("pymilvus") from pymilvus import MilvusClient
from pymilvus import MilvusClient except ImportError:
raise ImportError(
"pymilvus library is not installed. Please install it to proceed."
)
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@final
@dataclass @dataclass
class MilvusVectorDBStorage(BaseVectorStorage): class MilvusVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs client: MilvusClient, collection_name: str, **kwargs

View File

@@ -1,22 +1,11 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
import pipmaster as pm
import configparser import configparser
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
import asyncio import asyncio
if not pm.is_installed("pymongo"): from typing import Any, List, Union, final
pm.install("pymongo")
if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
raise ImportError(
"motor, pymongo library is not installed. Please install it to proceed."
) from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
@@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage):
await self._data.drop() await self._data.drop()
@final
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self): def __post_init__(self):
@@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage):
"""Get all procesed documents""" """Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED) return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self) -> None:
# Implement the method here
pass
async def update_doc_status(self, data: dict[str, Any]) -> None:
raise NotImplementedError
@final
@dataclass @dataclass
class MongoGraphStorage(BaseGraphStorage): class MongoGraphStorage(BaseGraphStorage):
""" """
@@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage):
return result return result
async def index_done_callback(self) -> None:
pass
@final
@dataclass @dataclass
class MongoVectorDBStorage(BaseVectorStorage): class MongoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")

View File

@@ -1,65 +1,10 @@
"""
NanoVectorDB Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import asyncio import asyncio
import os import os
from typing import Any from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
import time import time
from lightrag.utils import ( from lightrag.utils import (
@@ -71,11 +16,17 @@ from lightrag.base import (
BaseVectorStorage, BaseVectorStorage,
) )
try:
from nano_vectordb import NanoVectorDB
except ImportError as e:
raise ImportError(
"nano-vectordb library is not installed. Please install it to proceed."
) from e
@final
@dataclass @dataclass
class NanoVectorDBStorage(BaseVectorStorage): class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Initialize lock only for file operations # Initialize lock only for file operations
self._save_lock = asyncio.Lock() self._save_lock = asyncio.Lock()

View File

@@ -3,21 +3,11 @@ import inspect
import os import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Dict from typing import Any, List, Dict, final
import numpy as np import numpy as np
import pipmaster as pm
import configparser import configparser
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@@ -29,11 +19,25 @@ from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
try:
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
except ImportError as e:
raise ImportError(
"neo4j library is not installed. Please install it to proceed."
) from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@final
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
async def index_done_callback(self): async def index_done_callback(self) -> None:
print("KG successfully indexed.") pass
async def _label_exists(self, label: str) -> bool: async def _label_exists(self, label: str) -> bool:
"""Check if a label exists in the Neo4j database.""" """Check if a label exists in the Neo4j database."""

View File

@@ -1,58 +1,8 @@
"""
NetworkX Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import html import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast from typing import Any, cast, final
import networkx as nx
import numpy as np import numpy as np
@@ -65,7 +15,15 @@ from lightrag.base import (
BaseGraphStorage, BaseGraphStorage,
) )
try:
import networkx as nx
except ImportError as e:
raise ImportError(
"networkx library is not installed. Please install it to proceed."
) from e
@final
@dataclass @dataclass
class NetworkXStorage(BaseGraphStorage): class NetworkXStorage(BaseGraphStorage):
@staticmethod @staticmethod

View File

@@ -4,17 +4,11 @@ import asyncio
# import html # import html
# import os # import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any, Union, final
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
from lightrag.types import KnowledgeGraph from lightrag.types import KnowledgeGraph
import oracledb
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@@ -24,6 +18,14 @@ from ..base import (
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
try:
import oracledb
except ImportError as e:
raise ImportError(
"oracledb library is not installed. Please install it to proceed."
) from e
class OracleDB: class OracleDB:
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
@@ -170,6 +172,7 @@ class OracleDB:
raise raise
@final
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
raise NotImplementedError raise NotImplementedError
@final
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold") cosine_threshold = config.get("cosine_better_than_threshold")
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pass pass
async def index_done_callback(self): async def index_done_callback(self) -> None:
pass pass
#################### query method ############### #################### query method ###############
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
raise NotImplementedError raise NotImplementedError
@final
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: OracleDB
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config.get("embedding_batch_num", 10) self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################ #################### insert method ################
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
return embeddings, nodes_ids return embeddings, nodes_ids
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
"""写入graphhml图文件""" pass
logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!"
)
#################### query method ################# #################### query method #################
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:

View File

@@ -4,26 +4,19 @@ import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union, final
import numpy as np import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph from lightrag.types import KnowledgeGraph
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import sys import sys
import asyncpg
from tenacity import ( from tenacity import (
retry, retry,
retry_if_exception_type, retry_if_exception_type,
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from tqdm.asyncio import tqdm as tqdm_async
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
import asyncpg
from tqdm.asyncio import tqdm as tqdm_async
except ImportError as e:
raise ImportError(
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
) from e
class PostgreSQLDB: class PostgreSQLDB:
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
@@ -177,6 +179,7 @@ class PostgreSQLDB:
pass pass
@final
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
if is_namespace( pass
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into postgresql db!")
async def drop(self) -> None: async def drop(self) -> None:
raise NotImplementedError raise NotImplementedError
@final
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
logger.info("vector data had been saved into postgresql db!") pass
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
raise NotImplementedError raise NotImplementedError
@final
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use # db instance must be injected before use
# db: PostgreSQLDB # db: PostgreSQLDB
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data]) keys = ",".join([f"'{_id}'" for _id in keys])
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
result = await self.db.query(sql, multirows=True) result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None: if result is None:
return set(data) return set(keys)
else: else:
existed = set([element["id"] for element in result]) existed = set([element["id"] for element in result])
return set(data) - existed return set(keys) - existed
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=result[0]["updated_at"], updated_at=result[0]["updated_at"],
) )
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError
async def get_status_counts(self) -> Dict[str, int]: async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count" sql = """SELECT status as "status", COUNT(1) as "count"
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all procesed documents""" """Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED) return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self): async def index_done_callback(self) -> None:
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" pass
logger.info("Doc status had been saved into postgresql db!")
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
"""Update or insert document status """Update or insert document status
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
} }
await self.db.execute(sql, _data) await self.db.execute(sql, _data)
async def drop(self) -> None:
raise NotImplementedError
class PGGraphQueryException(Exception): class PGGraphQueryException(Exception):
"""Exception for the AGE queries.""" """Exception for the AGE queries."""
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
return self.details return self.details
@final
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with AGE in production") print("no preloading of graph with AGE in production")
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
print("KG successfully indexed.") pass
@staticmethod @staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
import os import os
from typing import Any from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -8,17 +8,20 @@ import hashlib
import uuid import uuid
from ..utils import logger from ..utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
import pipmaster as pm
import configparser import configparser
if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client")
from qdrant_client import QdrantClient, models
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
try:
from qdrant_client import QdrantClient, models
except ImportError as e:
raise ImportError(
"qdrant_client library is not installed. Please install it to proceed."
) from e
def compute_mdhash_id_for_qdrant( def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple" content: str, prefix: str = "", style: str = "simple"
@@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant(
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.") raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
@final
@dataclass @dataclass
class QdrantVectorDBStorage(BaseVectorStorage): class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs client: QdrantClient, collection_name: str, **kwargs

View File

@@ -1,24 +1,26 @@
import asyncio import asyncio
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any, Union, final
import numpy as np import numpy as np
import pipmaster as pm
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text
from tqdm import tqdm from tqdm import tqdm
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
try:
from sqlalchemy import create_engine, text
except ImportError as e:
raise ImportError(
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
) from e
class TiDB: class TiDB:
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
@@ -100,6 +102,7 @@ class TiDB:
raise raise
@final
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
return left_data return left_data
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( pass
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!")
async def drop(self) -> None: async def drop(self) -> None:
raise NotImplementedError raise NotImplementedError
@final
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError
async def index_done_callback(self) -> None:
raise NotImplementedError
@final
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
else: else:
return [] return []
async def index_done_callback(self) -> None:
pass
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError