Merge branch 'main' into select-datastore-in-api-server

yangdx
2025-02-12 09:49:18 +08:00
11 changed files with 320 additions and 59 deletions


@@ -85,7 +85,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
 ```python
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -95,12 +95,12 @@ from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
 WORKING_DIR = "./dickens"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 rag = LightRAG(
     working_dir=WORKING_DIR,
+    embedding_func=openai_embed,
     llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
@@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and
 ```python
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
     response_type: str = "Multiple Paragraphs"
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
     top_k: int = 60
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
     ...
 ```
 > default value of Top_k can be change by environment variables TOP_K.
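
For illustration, here is a minimal sketch (not part of this commit) of the TOP_K override mentioned in the note above. It assumes the `int(os.getenv("TOP_K", "60"))` default that appears in `lightrag/base.py` later in this diff, and that the variable is set before `lightrag` is imported:

```python
# Sketch only: override the default top_k via the TOP_K environment variable.
import os

os.environ["TOP_K"] = "100"  # must be set before lightrag is imported

from lightrag import QueryParam

param = QueryParam(mode="hybrid")
print(param.top_k)  # 100 instead of the default 60
```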


@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        LLM_MODEL,
-        prompt,
+        model=LLM_MODEL,
+        prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
         **kwargs,
     )
@@ -49,8 +55,10 @@ async def llm_model_func(
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed(
-        texts,
+        texts=texts,
         model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
     )
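
For context, the hunks above switch the demo to keyword arguments and thread `BASE_URL`/`API_KEY` into both calls. A minimal sketch (not part of this commit) of exercising the updated `llm_model_func`, assuming it is defined as above in the demo module:

```python
# Sketch only: call the demo's llm_model_func, which now forwards
# base_url=BASE_URL and api_key=API_KEY to openai_complete_if_cache.
import asyncio


async def main():
    answer = await llm_model_func(
        "Say hello in one short sentence.",
        system_prompt="You are a concise assistant.",
    )
    print(answer)


asyncio.run(main())
```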


@@ -0,0 +1,101 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
import asyncio
import nest_asyncio

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


# LLM model function
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        model=LLM_MODEL,
        prompt=prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url=BASE_URL,
        api_key=API_KEY,
        **kwargs,
    )


# Embedding function
async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts=texts,
        model=EMBEDDING_MODEL,
        base_url=BASE_URL,
        api_key=API_KEY,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print(f"{embedding_dim=}")
    return embedding_dim


# Initialize RAG instance
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=asyncio.run(get_embedding_dim()),
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
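
As a follow-on (not part of the new file), the same `rag` instance could also exercise the "mix" mode and a custom `response_type`, both of which are documented in the `QueryParam` docstrings added elsewhere in this commit:

```python
# Sketch only: mix mode combines knowledge-graph and vector retrieval.
print(
    rag.query(
        "Summarize the main conflict in three bullet points.",
        param=QueryParam(mode="mix", response_type="Bullet Points"),
    )
)
```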


@@ -1,7 +1,7 @@
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
 WORKING_DIR = "./dickens"
@@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
+    embedding_func=openai_embed,
     llm_model_func=gpt_4o_mini_complete,
     # llm_model_func=gpt_4o_complete
 )


@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "1.1.5"
+__version__ = "1.1.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"


@@ -27,30 +27,54 @@ T = TypeVar("T")
 @dataclass
 class QueryParam:
+    """Configuration parameters for query execution in LightRAG."""
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
     only_need_prompt: bool = False
+    """If True, only returns the generated prompt without producing a response."""
     response_type: str = "Multiple Paragraphs"
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
     stream: bool = False
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """If True, enables streaming output for real-time responses."""
     top_k: int = int(os.getenv("TOP_K", "60"))
-    # Number of document chunks to retrieve.
-    # top_n: int = 10
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
     hl_keywords: list[str] = field(default_factory=list)
+    """List of high-level keywords to prioritize in retrieval."""
     ll_keywords: list[str] = field(default_factory=list)
-    # Conversation history support
-    conversation_history: list[dict[str, str]] = field(
-        default_factory=list
-    )  # Format: [{"role": "user/assistant", "content": "message"}]
-    history_turns: int = (
-        3  # Number of complete conversation turns (user-assistant pairs) to consider
-    )
+    """List of low-level keywords to refine retrieval focus."""
+    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+    history_turns: int = 3
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""

 @dataclass
@@ -202,3 +226,7 @@ class DocStatusStorage(BaseKVStorage):
     async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
+
+    async def update_doc_status(self, data: dict[str, Any]) -> None:
+        """Updates the status of a document. By default, it calls upsert."""
+        await self.upsert(data)
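
Because the new base method defaults to `upsert`, existing backends keep working unchanged, while a backend can override it with a narrower update (as the PostgreSQL implementation does later in this diff). A hypothetical override might look like this (illustrative only; the class name is not from the diff, and the module's existing imports are assumed):

```python
# Sketch only: a custom DocStatusStorage that logs before delegating
# to the default upsert-based behaviour.
class LoggingDocStatusStorage(DocStatusStorage):
    async def update_doc_status(self, data: dict[str, Any]) -> None:
        for doc_id, fields in data.items():
            print(f"doc {doc_id} -> {fields.get('status')}")
        await self.upsert(data)
```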


@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
             if v["status"] == DocStatus.PENDING
         }

+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSED
+        }
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSING
+        }
+
     async def index_done_callback(self):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)
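
A small usage sketch (not part of this commit) for the new getters, assuming `doc_status` is an initialized `JsonDocStatusStorage`:

```python
# Sketch only: report how many documents are mid-pipeline vs. finished.
async def report_in_flight(doc_status) -> None:
    processing = await doc_status.get_processing_docs()
    processed = await doc_status.get_processed_docs()
    print(f"{len(processing)} documents processing, {len(processed)} processed")
```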


@@ -530,6 +530,32 @@ class PGDocStatusStorage(DocStatusStorage):
         )
         return data

+    async def update_doc_status(self, data: dict[str, dict]) -> None:
+        """
+        Updates only the document status, chunk count, and updated timestamp.
+        This method ensures that only relevant fields are updated instead of overwriting
+        the entire document record. If `updated_at` is not provided, the database will
+        automatically use the current timestamp.
+        """
+        sql = """
+        UPDATE LIGHTRAG_DOC_STATUS
+        SET status = $3,
+            chunks_count = $4,
+            updated_at = CURRENT_TIMESTAMP
+        WHERE workspace = $1 AND id = $2
+        """
+        for k, v in data.items():
+            _data = {
+                "workspace": self.db.workspace,
+                "id": k,
+                "status": v["status"].value,  # Convert Enum to string
+                "chunks_count": v.get(
+                    "chunks_count", -1
+                ),  # Default to -1 if not provided
+            }
+            await self.db.execute(sql, _data)

 class PGGraphQueryException(Exception):
     """Exception for the AGE queries."""


@@ -211,38 +211,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 @dataclass
 class LightRAG:
+    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
     working_dir: str = field(
         default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
     )
-    # Default not to use embedding cache
-    embedding_cache_config: dict = field(
+    """Directory where cache and temporary files are stored."""
+    embedding_cache_config: dict[str, Any] = field(
         default_factory=lambda: {
             "enabled": False,
             "similarity_threshold": 0.95,
             "use_llm_check": False,
         }
     )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """
     kv_storage: str = field(default="JsonKVStorage")
+    """Storage backend for key-value data."""
     vector_storage: str = field(default="NanoVectorDBStorage")
+    """Storage backend for vector embeddings."""
     graph_storage: str = field(default="NetworkXStorage")
+    """Storage backend for knowledge graphs."""
-    # logging
+    # Logging
     current_log_level = logger.level
-    log_level: str = field(default=current_log_level)
+    log_level: int = field(default=current_log_level)
+    """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
     log_dir: str = field(default=os.getcwd())
+    """Directory where logs are stored. Defaults to the current working directory."""
-    # text chunking
+    # Text chunking
     chunk_token_size: int = 1200
+    """Maximum number of tokens per text chunk when splitting documents."""
     chunk_overlap_token_size: int = 100
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
     tiktoken_model_name: str = "gpt-4o-mini"
+    """Model name used for tokenization when chunking text."""
-    # entity extraction
+    # Entity extraction
     entity_extract_max_gleaning: int = 1
+    """Maximum number of entity extraction attempts for ambiguous content."""
     entity_summary_to_max_tokens: int = 500
+    """Maximum number of tokens used for summarizing extracted entities."""
-    # node embedding
+    # Node embedding
     node_embedding_algorithm: str = "node2vec"
+    """Algorithm used for node embedding in knowledge graphs."""
-    node2vec_params: dict = field(
+    node2vec_params: dict[str, int] = field(
         default_factory=lambda: {
             "dimensions": 1536,
             "num_walks": 10,
@@ -252,26 +279,56 @@ class LightRAG:
"random_seed": 3, "random_seed": 3,
} }
) )
"""Configuration for the node2vec embedding algorithm:
- dimensions: Number of dimensions for embeddings.
- num_walks: Number of random walks per node.
- walk_length: Number of steps per random walk.
- window_size: Context window size for training.
- iterations: Number of iterations for training.
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
"""Function for computing text embeddings. Must be set before use."""
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = None # This must be set (we do want to separate llm from the corte, so no more default initialization)
embedding_batch_num: int = 32 embedding_batch_num: int = 32
"""Batch size for embedding computations."""
embedding_func_max_async: int = 16 embedding_func_max_async: int = 16
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
"""Name of the LLM model used for generating responses."""
# LLM
llm_model_func: callable = None # This must be set (we do want to separate llm from the corte, so no more default initialization)
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768")) llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16")) """Maximum number of tokens allowed per LLM response."""
llm_model_kwargs: dict = field(default_factory=dict)
llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
"""Maximum number of concurrent LLM calls."""
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Storage
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
# storage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
namespace_prefix: str = field(default="") namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = True enable_llm_cache: bool = True
# Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag """Enables caching for LLM responses to avoid redundant computations."""
enable_llm_cache_for_entity_extract: bool = True enable_llm_cache_for_entity_extract: bool = True
"""If True, enables caching for entity extraction steps to reduce LLM costs."""
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
# extension # extension
addon_params: dict[str, Any] = field(default_factory=dict) addon_params: dict[str, Any] = field(default_factory=dict)
@@ -279,8 +336,8 @@ class LightRAG:
         convert_response_to_json
     )
-    # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
+    """Storage type for tracking document processing statuses."""
     # Custom Chunking Function
     chunking_func: Callable[
@@ -799,7 +856,7 @@ class LightRAG:
         new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
         if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
+            logger.info("No new unique documents were found.")
             return
         # 4. Store status document
@@ -816,15 +873,16 @@ class LightRAG:
         each chunk for entity and relation extraction, and updating the
         document status.
-        1. Get all pending and failed documents
+        1. Get all pending, failed, and abnormally terminated processing documents.
         2. Split document content into chunks
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
-        # 1. get all pending and failed documents
+        # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}
-        # Fetch failed documents
+        processing_docs = await self.doc_status.get_processing_docs()
+        to_process_docs.update(processing_docs)
         failed_docs = await self.doc_status.get_failed_docs()
         to_process_docs.update(failed_docs)
         pendings_docs = await self.doc_status.get_pending_docs()
@@ -855,6 +913,7 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.PROCESSING,
                     "updated_at": datetime.now().isoformat(),
+                    "content": status_doc.content,
                     "content_summary": status_doc.content_summary,
                     "content_length": status_doc.content_length,
                     "created_at": status_doc.created_at,
@@ -886,11 +945,15 @@ class LightRAG:
             ]
             try:
                 await asyncio.gather(*tasks)
-                await self.doc_status.upsert(
+                await self.doc_status.update_doc_status(
                     {
                         doc_status_id: {
                             "status": DocStatus.PROCESSED,
                             "chunks_count": len(chunks),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
                             "updated_at": datetime.now().isoformat(),
                         }
                     }
@@ -899,11 +962,15 @@ class LightRAG:
             except Exception as e:
                 logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                await self.doc_status.upsert(
+                await self.doc_status.update_doc_status(
                     {
                         doc_status_id: {
                             "status": DocStatus.FAILED,
                             "error": str(e),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
                             "updated_at": datetime.now().isoformat(),
                         }
                     }
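
Taken together, these hunks let an interrupted run resume: documents stuck in PROCESSING are re-queued alongside FAILED and PENDING ones, and status updates now go through `update_doc_status` with the full status payload. A sketch (not part of this commit) of that selection logic, using only the storage calls shown above:

```python
# Sketch only: assemble the set of documents the pipeline will (re)process.
async def collect_docs_to_process(doc_status) -> dict:
    to_process: dict = {}
    to_process.update(await doc_status.get_processing_docs())  # interrupted mid-run
    to_process.update(await doc_status.get_failed_docs())      # failed previously
    to_process.update(await doc_status.get_pending_docs())     # never started
    return to_process
```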


@@ -103,17 +103,19 @@ async def openai_complete_if_cache(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
     default_headers = {
         "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
@@ -294,17 +296,19 @@ async def openai_embed(
     base_url: str = None,
     api_key: str = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
     default_headers = {
         "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
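
With this change the client receives `api_key` explicitly and only falls back to the `OPENAI_API_KEY` environment variable when no key is passed. A sketch (not part of this commit) of both call styles; the URL and key strings are placeholders:

```python
# Sketch only: explicit credentials vs. environment fallback for openai_embed.
import asyncio
from lightrag.llm.openai import openai_embed


async def main():
    # 1) An explicit key and base_url are passed straight into AsyncOpenAI(...).
    vecs = await openai_embed(
        ["hello world"],
        model="text-embedding-3-small",
        base_url="http://localhost:8000/v1",
        api_key="sk-placeholder",
    )
    # 2) With api_key omitted, the function reads os.environ["OPENAI_API_KEY"]
    #    and raises KeyError if it is not set.
    vecs2 = await openai_embed(["hello again"], model="text-embedding-3-small")
    print(vecs.shape, vecs2.shape)


asyncio.run(main())
```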


@@ -1504,7 +1504,7 @@ async def naive_query(
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, query, "default", cache_type="query"
+        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
     if cached_response is not None:
         return cached_response
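
The effect of this one-line change is that naive-mode lookups are cached under the query's own mode (`query_param.mode`) rather than the shared "default" key, so a response cached for one retrieval mode is no longer returned for a query issued in another mode.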