Merge remote-tracking branch 'origin/main' into graph-viewer-webui

This commit is contained in:
ArnoChen
2025-02-11 22:52:01 +08:00
4 changed files with 38 additions and 7 deletions

View File

@@ -226,3 +226,7 @@ class DocStatusStorage(BaseKVStorage):
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
    """Return all documents that are still pending.

    This version carries no implementation and always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
    """Record new status information for one or more documents.

    This default implementation does no filtering of its own: the whole
    ``data`` mapping is handed to ``upsert`` unchanged.

    Args:
        data: Status payload, forwarded to ``upsert`` as-is.
    """
    await self.upsert(data)

View File

@@ -471,7 +471,7 @@ class PGDocStatusStorage(DocStatusStorage):
self, status: DocStatus
) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status}
result = await self.db.query(sql, params, True)
return {
@@ -505,8 +505,8 @@ class PGDocStatusStorage(DocStatusStorage):
Args:
data: Dictionary of document IDs and their status data
"""
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6)
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7)
on conflict(id,workspace) do update set
content = EXCLUDED.content,
content_summary = EXCLUDED.content_summary,
@@ -530,6 +530,32 @@ class PGDocStatusStorage(DocStatusStorage):
)
return data
async def update_doc_status(self, data: dict[str, dict]) -> None:
    """Update only the status-related columns of existing document rows.

    For each ``doc_id -> fields`` entry in ``data`` one UPDATE statement is
    executed against ``LIGHTRAG_DOC_STATUS``, restricted to the current
    workspace. It sets ``status`` and ``chunks_count`` and always refreshes
    ``updated_at`` to ``CURRENT_TIMESTAMP``; no other column is touched, so
    the rest of the record is not overwritten.

    Args:
        data: Mapping of document ID to a dict containing ``status`` (an
            enum member or its plain string value) and, optionally,
            ``chunks_count``. A missing ``chunks_count`` is written as -1.

    Raises:
        KeyError: If an entry has no ``"status"`` key.
    """
    sql = """
    UPDATE LIGHTRAG_DOC_STATUS
    SET status = $3,
        chunks_count = $4,
        updated_at = CURRENT_TIMESTAMP
    WHERE workspace = $1 AND id = $2
    """
    for doc_id, fields in data.items():
        status = fields["status"]
        # NOTE(review): key order presumably has to match the $1..$4
        # placeholders above; confirm against the db.execute helper.
        params = {
            "workspace": self.db.workspace,
            "id": doc_id,
            # Enum members are stored by their value; a status that is
            # already a plain string is passed through instead of raising
            # AttributeError on ``.value``.
            "status": getattr(status, "value", status),
            # Default to -1 if not provided
            "chunks_count": fields.get("chunks_count", -1),
        }
        await self.db.execute(sql, params)
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
@@ -1103,6 +1129,7 @@ TABLES = {
"ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
workspace varchar(255) NOT NULL,
id varchar(255) NOT NULL,
content TEXT,
content_summary varchar(255) NULL,
content_length int4 NULL,
chunks_count int4 NULL,

View File

@@ -632,7 +632,7 @@ class LightRAG:
]
try:
await asyncio.gather(*tasks)
await self.doc_status.upsert(
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.PROCESSED,
@@ -649,7 +649,7 @@ class LightRAG:
except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.upsert(
await self.doc_status.update_doc_status(
{
doc_status_id: {
"status": DocStatus.FAILED,