Fix cache bugs
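Pass base_url and api_key through the demo helpers and the OpenAI wrappers (falling back to OPENAI_API_KEY when no explicit key is given), add get_processed_docs/get_processing_docs to JsonDocStatusStorage, re-queue documents left in PROCESSING after an abnormal termination, and persist the full content fields on every document status update.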
@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        LLM_MODEL,
-        prompt,
+        model=LLM_MODEL,
+        prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
         **kwargs,
     )
 
@@ -49,8 +55,10 @@ async def llm_model_func(
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed(
-        texts,
+        texts=texts,
         model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
     )
 
 
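Both demo wrappers above now pass everything by keyword, so a quick smoke test is enough to confirm the wiring. A minimal sketch, assuming BASE_URL and API_KEY point at a reachable OpenAI-compatible endpoint; main is an illustrative name:

import asyncio

async def main():
    # Exercises llm_model_func and embedding_func as defined above.
    answer = await llm_model_func("Say hello in one word.")
    print(answer)
    vectors = await embedding_func(["hello", "world"])
    print(vectors.shape)  # one embedding row per input text

asyncio.run(main())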
@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
             if v["status"] == DocStatus.PENDING
         }
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSED
+        }
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSING
+        }
+
     async def index_done_callback(self):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)
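The new accessors mirror the existing get_pending_docs/get_failed_docs and feed the recovery step later in this commit. A minimal sketch of that consumption pattern, assuming an initialized storage instance; resume is an illustrative name:

async def resume(doc_status) -> dict:
    # The three status sets are disjoint, so update order does not matter;
    # this mirrors the pipeline change further down in this commit.
    to_process = {}
    to_process.update(await doc_status.get_processing_docs())  # interrupted mid-run
    to_process.update(await doc_status.get_failed_docs())
    to_process.update(await doc_status.get_pending_docs())
    return to_process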
@@ -543,7 +543,7 @@ class LightRAG:
         new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
 
         if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
+            logger.info("No new unique documents were found.")
             return
 
         # 4. Store status document
@@ -560,15 +560,16 @@ class LightRAG:
         each chunk for entity and relation extraction, and updating the
         document status.
 
-        1. Get all pending and failed documents
+        1. Get all pending, failed, and abnormally terminated processing documents.
         2. Split document content into chunks
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
-        # 1. get all pending and failed documents
+        # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}
 
-        # Fetch failed documents
+        processing_docs = await self.doc_status.get_processing_docs()
+        to_process_docs.update(processing_docs)
         failed_docs = await self.doc_status.get_failed_docs()
         to_process_docs.update(failed_docs)
         pendings_docs = await self.doc_status.get_pending_docs()
@@ -599,6 +600,7 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.PROCESSING,
                     "updated_at": datetime.now().isoformat(),
+                    "content": status_doc.content,
                     "content_summary": status_doc.content_summary,
                     "content_length": status_doc.content_length,
                     "created_at": status_doc.created_at,
@@ -635,6 +637,10 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.PROCESSED,
                     "chunks_count": len(chunks),
+                    "content": status_doc.content,
+                    "content_summary": status_doc.content_summary,
+                    "content_length": status_doc.content_length,
+                    "created_at": status_doc.created_at,
                     "updated_at": datetime.now().isoformat(),
                 }
             }
@@ -648,6 +654,10 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.FAILED,
                     "error": str(e),
+                    "content": status_doc.content,
+                    "content_summary": status_doc.content_summary,
+                    "content_length": status_doc.content_length,
+                    "created_at": status_doc.created_at,
                     "updated_at": datetime.now().isoformat(),
                 }
             }
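The extra fields matter because get_processed_docs/get_processing_docs rebuild each record with DocProcessingStatus(**v), so a record written without them presumably could not be rehydrated after a restart. A hedged illustration of the full record shape for a processed document; all field values are hypothetical:

record = {
    "status": "processed",               # stored DocStatus value
    "chunks_count": 3,
    "content": "full document text",
    "content_summary": "full document te",
    "content_length": 18,
    "created_at": "2025-01-01T00:00:00",
    "updated_at": "2025-01-01T00:05:00",
}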
@@ -103,17 +103,17 @@ async def openai_complete_if_cache(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
@@ -294,17 +294,17 @@ async def openai_embed(
     base_url: str = None,
     api_key: str = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
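Both wrappers now resolve credentials the same way: an explicit api_key argument wins, and only when it is absent does OPENAI_API_KEY fill in, instead of the old behavior of mutating os.environ. A minimal sketch of that precedence; resolve_api_key is an illustrative name:

import os

def resolve_api_key(api_key: str | None = None) -> str:
    # Mirrors the fixed branches above: explicit argument first,
    # then the environment; raises KeyError if neither is set.
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]
    return api_key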