Added @final decorators and the required storage methods (drop, get_by_ids); cleaned up imports

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -4,26 +4,19 @@ import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, final
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import sys
import asyncpg
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from tqdm.asyncio import tqdm as tqdm_async
from ..base import (
BaseGraphStorage,
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
import asyncpg
from tqdm.asyncio import tqdm as tqdm_async
except ImportError as e:
raise ImportError(
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
) from e
class PostgreSQLDB:
def __init__(self, config, **kwargs):
@@ -177,6 +179,7 @@ class PostgreSQLDB:
pass
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None:
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into postgresql db!")
pass
async def drop(self) -> None:
raise NotImplementedError
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data)
async def index_done_callback(self) -> None:
logger.info("vector data had been saved into postgresql db!")
pass
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
raise NotImplementedError
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use
# db: PostgreSQLDB
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return keys that don't exist in storage.

    Looks up the given ids in LIGHTRAG_DOC_STATUS for the current
    workspace (``self.db.workspace``) and returns the ones that have no row.

    Args:
        keys: Candidate document ids to check.

    Returns:
        The subset of ``keys`` not found in the table. An empty input
        yields an empty set without querying the database.
    """
    if not keys:
        # "id IN ()" is not valid SQL, and there is nothing to filter anyway.
        return set()

    # Keep the SQL fragment in its own name. Rebinding ``keys`` to the joined
    # string made the set() calls below operate on its characters instead of
    # on the caller's ids.
    # NOTE(review): ids and workspace are interpolated straight into the SQL
    # text; if db.query accepts bound parameters, prefer those — confirm.
    id_list = ",".join([f"'{_id}'" for _id in keys])
    sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({id_list})"
    result = await self.db.query(sql, multirows=True)

    # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
    if result is None:
        return set(keys)

    existed = {element["id"] for element in result}
    return set(keys) - existed
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
updated_at=result[0]["updated_at"],
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
raise NotImplementedError
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")
async def index_done_callback(self) -> None:
pass
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
}
await self.db.execute(sql, _data)
async def drop(self) -> None:
raise NotImplementedError
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
return self.details
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
print("KG successfully indexed.")
async def index_done_callback(self) -> None:
pass
@staticmethod
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: