fixed bugs
@@ -90,7 +90,7 @@ class BaseKVStorage(StorageNameSpace):
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         raise NotImplementedError

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError

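Side note: a minimal, self-contained sketch of the updated filter_keys contract (the InMemoryKV class and doc ids below are illustrative only, not part of this commit) — the method now takes a set[str] of candidate keys and returns the ones that are not yet stored:

import asyncio


class InMemoryKV:
    """Stand-in for a BaseKVStorage subclass, used only to illustrate the new signature."""

    def __init__(self) -> None:
        self._data: dict[str, dict] = {"doc-1": {"text": "already stored"}}

    async def filter_keys(self, data: set[str]) -> set[str]:
        # "un-exist keys": keys the caller still needs to insert
        return {key for key in data if key not in self._data}


async def main() -> None:
    kv = InMemoryKV()
    print(await kv.filter_keys({"doc-1", "doc-2"}))  # {'doc-2'}


asyncio.run(main())
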
@@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         return set([s for s in data if s not in self._data])

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -52,17 +52,16 @@ import os
 from dataclasses import dataclass
 from typing import Any, Union

-from lightrag.utils import (
-    logger,
-    load_json,
-    write_json,
-)
-
 from lightrag.base import (
-    DocStatus,
     DocProcessingStatus,
+    DocStatus,
     DocStatusStorage,
 )
+from lightrag.utils import (
+    load_json,
+    logger,
+    write_json,
+)


 @dataclass
@@ -75,15 +74,17 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(
-            [
-                k
-                for k in data
-                if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
-            ]
-        )
+        return {k for k, _ in self._data.items() if k in data}
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        result: list[dict[str, Any]] = []
+        for id in ids:
+            data = self._data.get(id, None)
+            if data:
+                result.append(data)
+        return result

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
@@ -94,11 +95,19 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.FAILED
+        }

     async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PENDING
+        }

     async def index_done_callback(self):
         """Save data to file after indexing"""
@@ -118,7 +127,11 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
         """Get document status by ID"""
-        return self._data.get(doc_id)
+        data = self._data.get(doc_id)
+        if data:
+            return DocProcessingStatus(**data)
+        else:
+            return None

     async def delete(self, doc_ids: list[str]):
         """Delete document status by IDs"""
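Side note: a rough, standalone sketch of what callers get back after this change — DocProcessingStatus objects instead of raw dicts. The stand-in DocStatus/DocProcessingStatus definitions and the doc id below are assumptions for illustration; the real classes live in lightrag.base:

import asyncio
from dataclasses import dataclass
from enum import Enum


class DocStatus(str, Enum):
    FAILED = "failed"
    PROCESSED = "processed"


@dataclass
class DocProcessingStatus:
    status: DocStatus


class DocStatusStore:
    def __init__(self) -> None:
        # Raw JSON-style records, as load_json() would return them
        self._data = {"doc-a": {"status": DocStatus.FAILED}}

    async def get(self, doc_id: str):
        data = self._data.get(doc_id)
        return DocProcessingStatus(**data) if data else None


async def main() -> None:
    status = await DocStatusStore().get("doc-a")
    if status and status.status == DocStatus.FAILED:
        print("doc-a needs to be retried")


asyncio.run(main())
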
@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return list(self._data.find({"_id": {"$in": ids}}))

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         existing_ids = [
             str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
         ]
@@ -421,7 +421,7 @@ class PGDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         pass

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         keys = ",".join([f"'{_id}'" for _id in data])
         sql = (
@@ -32,7 +32,7 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
         for key in data:
             pipe.exists(f"{self.namespace}:{key}")
@@ -1,28 +1,11 @@
 import asyncio
 import os
-from tqdm.asyncio import tqdm as tqdm_async
+from collections.abc import Coroutine
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Coroutine, Optional, Type, Union, cast
-from .operate import (
-    chunking_by_token_size,
-    extract_entities,
-    extract_keywords_only,
-    kg_query,
-    kg_query_with_keywords,
-    mix_kg_vector_query,
-    naive_query,
-)
-
-from .utils import (
-    EmbeddingFunc,
-    compute_mdhash_id,
-    limit_async_func_call,
-    convert_response_to_json,
-    logger,
-    set_logger,
-)
+from typing import Any, Callable, Optional, Type, Union, cast
+
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -33,10 +16,25 @@ from .base import (
     QueryParam,
     StorageNameSpace,
 )
-
 from .namespace import NameSpace, make_namespace
+from .operate import (
+    chunking_by_token_size,
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
+)
 from .prompt import GRAPH_FIELD_SEP
-
+from .utils import (
+    EmbeddingFunc,
+    compute_mdhash_id,
+    convert_response_to_json,
+    limit_async_func_call,
+    logger,
+    set_logger,
+)

 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
@@ -67,7 +65,6 @@ STORAGES = {

 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
-
     # Get the caller's module and package
     import inspect

@@ -113,7 +110,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 @dataclass
 class LightRAG:
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
     )
     # Default not to use embedding cache
     embedding_cache_config: dict = field(
@@ -496,15 +493,15 @@ class LightRAG:
         }

         # 3. Filter out already processed documents
-        add_doc_keys: set[str] = set()
+        new_doc_keys: set[str] = set()
         # Get docs ids
-        in_process_keys = list(new_docs.keys())
+        in_process_keys = set(new_docs.keys())
         # Get in progress docs ids
-        excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
+        excluded_ids = await self.doc_status.filter_keys(list(in_process_keys))
         # Exclude already in process
-        add_doc_keys = new_docs.keys() - excluded_ids
+        new_doc_keys = in_process_keys - excluded_ids
         # Filter
-        new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
+        new_docs = {doc_id: new_docs[doc_id] for doc_id in new_doc_keys}

         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
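Side note: a toy walk-through of the set arithmetic above (document ids invented for illustration) — candidate ids minus the ids returned by the doc-status storage leaves only the documents that still need processing:

new_docs = {"doc-1": {"content": "a"}, "doc-2": {"content": "b"}, "doc-3": {"content": "c"}}

in_process_keys = set(new_docs.keys())         # {'doc-1', 'doc-2', 'doc-3'}
excluded_ids = {"doc-2"}                       # ids to exclude, e.g. already tracked
new_doc_keys = in_process_keys - excluded_ids  # {'doc-1', 'doc-3'}

new_docs = {doc_id: new_docs[doc_id] for doc_id in new_doc_keys}
print(sorted(new_docs))  # ['doc-1', 'doc-3']
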
@@ -562,15 +559,12 @@ class LightRAG:

         # 3. iterate over batches
         tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
-        for batch_idx, ids_doc_processing_status in tqdm_async(
-            enumerate(batch_docs_list),
-            desc="Process Batches",
-        ):
+
+        logger.info(f"Number of batches to process: {len(batch_docs_list)}.")
+
+        for batch_idx, ids_doc_processing_status in enumerate(batch_docs_list):
             # 4. iterate over batch
-            for id_doc_processing_status in tqdm_async(
-                ids_doc_processing_status,
-                desc=f"Process Batch {batch_idx}",
-            ):
+            for id_doc_processing_status in ids_doc_processing_status:
                 id_doc, status_doc = id_doc_processing_status
                 # Update status in processing
                 await self.doc_status.upsert(
@@ -644,6 +638,7 @@ class LightRAG:
                        }
                    )
                    continue
+            logger.info(f"Completed batch {batch_idx + 1} of {len(batch_docs_list)}.")

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -895,7 +890,6 @@ class LightRAG:
         1. Extract keywords from the 'query' using new function in operate.py.
         2. Then run the standard aquery() flow with the final prompt (formatted_question).
         """
-
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
             self.aquery_with_separate_keyword_extraction(query, prompt, param)
@@ -908,7 +902,6 @@ class LightRAG:
         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
         2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
         """
-
         # ---------------------
         # STEP 1: Keyword Extraction
         # ---------------------