added final, required methods and cleaned import
This commit is contained in:
@@ -4,26 +4,19 @@ import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import sys
|
||||
|
||||
import asyncpg
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -41,6 +34,15 @@ if sys.platform.startswith("win"):
|
||||
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
try:
|
||||
import asyncpg
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"asyncpg, tqdm_async library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class PostgreSQLDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -177,6 +179,7 @@ class PostgreSQLDB:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage):
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into postgresql db!")
|
||||
pass
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
@@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
logger.info("vector data had been saved into postgresql db!")
|
||||
pass
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
@@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
keys = ",".join([f"'{_id}'" for _id in data])
|
||||
keys = ",".join([f"'{_id}'" for _id in keys])
|
||||
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||
result = await self.db.query(sql, multirows=True)
|
||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||
if result is None:
|
||||
return set(data)
|
||||
return set(keys)
|
||||
else:
|
||||
existed = set([element["id"] for element in result])
|
||||
return set(data) - existed
|
||||
return set(keys) - existed
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||
@@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
updated_at=result[0]["updated_at"],
|
||||
)
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_status_counts(self) -> Dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
sql = """SELECT status as "status", COUNT(1) as "count"
|
||||
@@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
"""Get all procesed documents"""
|
||||
return await self.get_docs_by_status(DocStatus.PROCESSED)
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
|
||||
logger.info("Doc status had been saved into postgresql db!")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""Update or insert document status
|
||||
@@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
}
|
||||
await self.db.execute(sql, _data)
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PGGraphQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
@@ -593,11 +595,9 @@ class PGGraphQueryException(Exception):
|
||||
return self.details
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
@@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
||||
|
Reference in New Issue
Block a user