Fix cache bugs

This commit is contained in:
zrguo
2025-02-11 13:28:18 +08:00
parent 24e0f0390e
commit 2d2ed19095
4 changed files with 49 additions and 15 deletions

View File

@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
# BUG FIX: BASE_URL and API_KEY are free-form strings; the previous int(...)
# wrapper raised ValueError for any URL or key value (i.e. always).
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")
# Create the working directory on first run.
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
model=LLM_MODEL,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=BASE_URL,
api_key=API_KEY,
**kwargs,
)
@@ -49,8 +55,10 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed a batch of texts via the configured OpenAI-compatible endpoint.

    Args:
        texts: Input strings to embed.

    Returns:
        np.ndarray: Embedding vectors, one row per input text.
    """
    # Diff residue left both a positional `texts,` and `texts=texts,` in this
    # call — invalid Python (duplicate argument). Keep the keyword form only.
    return await openai_embed(
        texts=texts,
        model=EMBEDDING_MODEL,
        base_url=BASE_URL,
        api_key=API_KEY,
    )

View File

@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
if v["status"] == DocStatus.PENDING
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
    """Return every document whose status is PROCESSED, keyed by doc id."""
    processed: dict[str, DocProcessingStatus] = {}
    for doc_id, fields in self._data.items():
        if fields["status"] == DocStatus.PROCESSED:
            processed[doc_id] = DocProcessingStatus(**fields)
    return processed
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
    """Return every document whose status is PROCESSING, keyed by doc id."""
    in_flight: dict[str, DocProcessingStatus] = {}
    for doc_id, fields in self._data.items():
        if fields["status"] == DocStatus.PROCESSING:
            in_flight[doc_id] = DocProcessingStatus(**fields)
    return in_flight
async def index_done_callback(self):
    """Persist the in-memory status map after an indexing pass.

    Writes self._data to self._file_name via write_json. NOTE(review): the
    write itself is synchronous despite the async signature.
    """
    write_json(self._data, self._file_name)

View File

@@ -543,7 +543,7 @@ class LightRAG:
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
if not new_docs:
logger.info("No new unique documents were found.")
return
# 4. Store status document
@@ -560,15 +560,16 @@ class LightRAG:
each chunk for entity and relation extraction, and updating the
document status.
1. Get all pending and failed documents
1. Get all pending, failed, and abnormally terminated processing documents.
2. Split document content into chunks
3. Process each chunk for entity and relation extraction
4. Update the document status
"""
# 1. get all pending and failed documents
# 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {}
# Fetch failed documents
processing_docs = await self.doc_status.get_processing_docs()
to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs()
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs()
@@ -599,6 +600,7 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
@@ -635,6 +637,10 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
@@ -648,6 +654,10 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}

View File

@@ -103,17 +103,17 @@ async def openai_complete_if_cache(
) -> str:
if history_messages is None:
history_messages = []
if not api_key:
    api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
    AsyncOpenAI(default_headers=default_headers, api_key=api_key)
    if base_url is None
    else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
@@ -294,17 +294,17 @@ async def openai_embed(
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if not api_key:
    api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
    AsyncOpenAI(default_headers=default_headers, api_key=api_key)
    if base_url is None
    else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"