added final, required methods and cleaned import

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -92,6 +92,7 @@ class StorageNameSpace(ABC):
@dataclass
class BaseVectorStorage(StorageNameSpace, ABC):
embedding_func: EmbeddingFunc
cosine_better_than_threshold: float
meta_fields: set[str] = field(default_factory=set)
@abstractmethod

View File

@@ -5,21 +5,11 @@ import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import numpy as np
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
from lightrag.types import KnowledgeGraph
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from tenacity import (
retry,
retry_if_exception_type,
@@ -37,6 +27,16 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
except ImportError as e:
raise ImportError(
"psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed."
) from e
class AGEQueryException(Exception):
"""Exception for the AGE queries."""
@@ -55,6 +55,7 @@ class AGEQueryException(Exception):
return self.details
@final
@dataclass
class AGEStorage(BaseGraphStorage):
@staticmethod
@@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
@staticmethod
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
"""
@@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
async def index_done_callback(self) -> None:
pass

View File

@@ -1,19 +1,25 @@
import asyncio
from dataclasses import dataclass
from typing import Any
from typing import Any, final
import numpy as np
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
try:
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
except ImportError as e:
raise ImportError(
"chromadb library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = None
def __post_init__(self):
try:
config = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -1,8 +1,8 @@
import os
import time
import asyncio
from typing import Any
import faiss
from typing import Any, final
import json
import numpy as np
from tqdm.asyncio import tqdm as tqdm_async
@@ -16,7 +16,15 @@ from lightrag.base import (
BaseVectorStorage,
)
try:
import faiss
except ImportError as e:
raise ImportError(
"faiss library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class FaissVectorDBStorage(BaseVectorStorage):
"""
@@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
"""
cosine_better_than_threshold: float = None
def __post_init__(self):
# Grab config values if available
config = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -3,13 +3,11 @@ import inspect
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any, Dict, List, final
import numpy as np
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
from tenacity import (
retry,
retry_if_exception_type,
@@ -22,7 +20,17 @@ from lightrag.utils import logger
from ..base import BaseGraphStorage
try:
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
except ImportError as e:
raise ImportError(
"gremlin library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class GremlinStorage(BaseGraphStorage):
@staticmethod
@@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage):
if self._driver:
self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
async def index_done_callback(self) -> None:
pass
@staticmethod
def _to_value_map(value: Any) -> str:

View File

@@ -1,56 +1,6 @@
"""
JsonDocStatus Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
from dataclasses import dataclass
import os
from typing import Any, Union
from typing import Any, Union, final
from lightrag.base import (
DocProcessingStatus,
@@ -64,6 +14,7 @@ from lightrag.utils import (
)
@final
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
"""JSON implementation of document status storage"""
@@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage):
self._data: dict[str, Any] = load_json(self._file_name) or {}
logger.info(f"Loaded document status storage with {len(self._data)} records")
async def filter_keys(self, data: set[str]) -> set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
return set(data) - set(self._data.keys())
return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
@@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage):
return counts
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
@@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage):
}
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
@@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage):
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
@@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage):
}
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING
}
async def index_done_callback(self):
"""Save data to file after indexing"""
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
async def upsert(self, data: dict[str, Any]) -> None:
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
self._data.update(data)
await self.index_done_callback()
@@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage):
return self._data.get(id)
async def delete(self, doc_ids: list[str]):
"""Delete document status by IDs"""
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await self.index_done_callback()
async def drop(self) -> None:
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
raise NotImplementedError

View File

@@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage):
self._lock = asyncio.Lock()
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def index_done_callback(self):
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
async def get_by_id(self, id: str) -> dict[str, Any] | None:

View File

@@ -1,27 +1,29 @@
import asyncio
import os
from typing import Any
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
import configparser
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
try:
from pymilvus import MilvusClient
except ImportError:
raise ImportError(
"pymilvus library is not installed. Please install it to proceed."
)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod
def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs

View File

@@ -1,22 +1,11 @@
import os
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
import asyncio
if not pm.is_installed("pymongo"):
pm.install("pymongo")
if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from typing import Any, List, Union, final
from ..base import (
BaseGraphStorage,
@@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
raise ImportError(
"motor, pymongo library is not installed. Please install it to proceed."
) from e
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage):
await self._data.drop()
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
@@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage):
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self) -> None:
# Implement the method here
pass
async def update_doc_status(self, data: dict[str, Any]) -> None:
raise NotImplementedError
@final
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
@@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def index_done_callback(self) -> None:
pass
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")

View File

@@ -1,65 +1,10 @@
"""
NanoVectorDB Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import asyncio
import os
from typing import Any
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
import time
from lightrag.utils import (
@@ -71,11 +16,17 @@ from lightrag.base import (
BaseVectorStorage,
)
try:
from nano_vectordb import NanoVectorDB
except ImportError as e:
raise ImportError(
"nano-vectordb library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()

View File

@@ -3,21 +3,11 @@ import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, List, Dict
from typing import Any, List, Dict, final
import numpy as np
import pipmaster as pm
import configparser
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
from tenacity import (
retry,
stop_after_attempt,
@@ -29,11 +19,25 @@ from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
try:
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
except ImportError as e:
raise ImportError(
"neo4j library is not installed. Please install it to proceed."
) from e
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
@staticmethod
@@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
async def index_done_callback(self) -> None:
pass
async def _label_exists(self, label: str) -> bool:
"""Check if a label exists in the Neo4j database."""

View File

@@ -1,58 +1,8 @@
"""
NetworkX Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import html
import os
from dataclasses import dataclass
from typing import Any, cast
import networkx as nx
from typing import Any, cast, final
import numpy as np
@@ -65,7 +15,15 @@ from lightrag.base import (
BaseGraphStorage,
)
try:
import networkx as nx
except ImportError as e:
raise ImportError(
"networkx library is not installed. Please install it to proceed."
) from e
@final
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod

View File

@@ -4,17 +4,11 @@ import asyncio
# import html
# import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final
import numpy as np
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
from lightrag.types import KnowledgeGraph
import oracledb
from ..base import (
BaseGraphStorage,
@@ -24,6 +18,14 @@ from ..base import (
from ..namespace import NameSpace, is_namespace
from ..utils import logger
try:
import oracledb
except ImportError as e:
raise ImportError(
"oracledb library is not installed. Please install it to proceed."
) from e
class OracleDB:
def __init__(self, config, **kwargs):
@@ -170,6 +172,7 @@ class OracleDB:
raise
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage):
raise NotImplementedError
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pass
async def index_done_callback(self):
async def index_done_callback(self) -> None:
pass
#################### query method ###############
@@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage):
raise NotImplementedError
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: OracleDB
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage):
return embeddings, nodes_ids
async def index_done_callback(self) -> None:
"""写入graphhml图文件"""
logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!"
)
pass
#################### query method #################
async def has_node(self, node_id: str) -> bool:

View File

@@ -4,26 +4,19 @@ import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, final
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import sys
import asyncpg
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from tqdm.asyncio import tqdm as tqdm_async
from ..base import (
BaseGraphStorage,
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
import asyncpg
from tqdm.asyncio import tqdm as tqdm_async
except ImportError as e:
raise ImportError(
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
) from e
class PostgreSQLDB:
def __init__(self, config, **kwargs):
@@ -177,6 +179,7 @@ class PostgreSQLDB:
pass
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None:
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into postgresql db!")
pass
async def drop(self) -> None:
raise NotImplementedError
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data)
async def index_done_callback(self) -> None:
logger.info("vector data had been saved into postgresql db!")
pass
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
raise NotImplementedError
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use
# db: PostgreSQLDB
async def filter_keys(self, data: set[str]) -> set[str]:
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data])
keys = ",".join([f"'{_id}'" for _id in keys])
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None:
return set(data)
return set(keys)
else:
existed = set([element["id"] for element in result])
return set(data) - existed
return set(keys) - existed
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=result[0]["updated_at"],
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")
async def index_done_callback(self) -> None:
pass
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
}
await self.db.execute(sql, _data)
async def drop(self) -> None:
raise NotImplementedError
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
return self.details
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
print("KG successfully indexed.")
async def index_done_callback(self) -> None:
pass
@staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from typing import Any
from typing import Any, final
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
@@ -8,17 +8,20 @@ import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
import configparser
if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client")
from qdrant_client import QdrantClient, models
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
try:
from qdrant_client import QdrantClient, models
except ImportError as e:
raise ImportError(
"qdrant_client library is not installed. Please install it to proceed."
) from e
def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple"
@@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant(
raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod
def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs

View File

@@ -1,24 +1,26 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final
import numpy as np
import pipmaster as pm
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text
from tqdm import tqdm
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
try:
from sqlalchemy import create_engine, text
except ImportError as e:
raise ImportError(
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
) from e
class TiDB:
def __init__(self, config, **kwargs):
@@ -100,6 +102,7 @@ class TiDB:
raise
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
await self.db.execute(merge_sql, data)
return left_data
async def index_done_callback(self):
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!")
async def index_done_callback(self) -> None:
pass
async def drop(self) -> None:
raise NotImplementedError
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
async def index_done_callback(self) -> None:
raise NotImplementedError
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
else:
return []
async def index_done_callback(self) -> None:
pass
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError