diff --git a/README.md b/README.md
index 62dc032b..cf1d86aa 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-![]() |
@@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and
```python
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+ """Specifies the retrieval mode:
+ - "local": Focuses on context-dependent information.
+ - "global": Utilizes global knowledge.
+ - "hybrid": Combines local and global retrieval methods.
+ - "naive": Performs a basic search without advanced techniques.
+ - "mix": Integrates knowledge graph and vector retrieval.
+ """
only_need_context: bool = False
+ """If True, only returns the retrieved context without generating a response."""
response_type: str = "Multiple Paragraphs"
- # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+ """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
top_k: int = 60
- # Number of tokens for the original chunks.
+ """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = 4000
- # Number of tokens for the relationship descriptions
+ """Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = 4000
- # Number of tokens for the entity descriptions
+ """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_local_context: int = 4000
+ """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+ ...
```
> The default value of `top_k` can be changed via the `TOP_K` environment variable.
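
The fields documented above are passed straight to `rag.query` via `QueryParam`. A minimal usage sketch, assuming `rag` is an already-initialized `LightRAG` instance:

```python
from lightrag import QueryParam

# Hybrid retrieval with a larger candidate pool; return only the retrieved context.
param = QueryParam(
    mode="hybrid",
    top_k=100,               # entities in "local" mode, relationships in "global" mode
    only_need_context=True,  # skip answer generation and return the context instead
)
print(rag.query("What are the top themes in this story?", param=param))
```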
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index 8173dc5b..e2d63e41 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
- LLM_MODEL,
- prompt,
+ model=LLM_MODEL,
+ prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ base_url=BASE_URL,
+ api_key=API_KEY,
**kwargs,
)
@@ -49,8 +55,10 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
- texts,
+ texts=texts,
model=EMBEDDING_MODEL,
+ base_url=BASE_URL,
+ api_key=API_KEY,
)
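
Because the demo now passes everything as keyword arguments, `openai_complete_if_cache` can also be called directly against any OpenAI-compatible endpoint. A small sketch, with a placeholder endpoint URL and key:

```python
import asyncio

from lightrag.llm.openai import openai_complete_if_cache


async def main():
    # The base_url and api_key below are placeholders for your own deployment.
    answer = await openai_complete_if_cache(
        model="gpt-4o-mini",
        prompt="Say hello in one short sentence.",
        base_url="http://localhost:8000/v1",
        api_key="sk-placeholder",
    )
    print(answer)


asyncio.run(main())
```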
diff --git a/examples/lightrag_api_openai_compatible_demo_simplified.py b/examples/lightrag_api_openai_compatible_demo_simplified.py
new file mode 100644
index 00000000..fabbb3e2
--- /dev/null
+++ b/examples/lightrag_api_openai_compatible_demo_simplified.py
@@ -0,0 +1,101 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", DEFAULT_RAG_DIR)
+print(f"WORKING_DIR: {WORKING_DIR}")
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+
+# LLM model function
+
+
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ model=LLM_MODEL,
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ base_url=BASE_URL,
+ api_key=API_KEY,
+ **kwargs,
+ )
+
+
+# Embedding function
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await openai_embed(
+ texts=texts,
+ model=EMBEDDING_MODEL,
+ base_url=BASE_URL,
+ api_key=API_KEY,
+ )
+
+
+async def get_embedding_dim():
+ test_text = ["This is a test sentence."]
+ embedding = await embedding_func(test_text)
+ embedding_dim = embedding.shape[1]
+ print(f"{embedding_dim=}")
+ return embedding_dim
+
+
+# Initialize RAG instance
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=asyncio.run(get_embedding_dim()),
+ max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+ func=embedding_func,
+ ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index 7a43a710..c5393fc8 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
WORKING_DIR = "./dickens"
@@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
+ embedding_func=openai_embed,
llm_model_func=gpt_4o_mini_complete,
# llm_model_func=gpt_4o_complete
)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index d68bded0..031502d6 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "1.1.5"
+__version__ = "1.1.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/base.py b/lightrag/base.py
index 1a7f9c2e..0e3f1dc6 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -27,30 +27,54 @@ T = TypeVar("T")
@dataclass
class QueryParam:
+ """Configuration parameters for query execution in LightRAG."""
+
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+ """Specifies the retrieval mode:
+ - "local": Focuses on context-dependent information.
+ - "global": Utilizes global knowledge.
+ - "hybrid": Combines local and global retrieval methods.
+ - "naive": Performs a basic search without advanced techniques.
+ - "mix": Integrates knowledge graph and vector retrieval.
+ """
+
only_need_context: bool = False
+ """If True, only returns the retrieved context without generating a response."""
+
only_need_prompt: bool = False
+ """If True, only returns the generated prompt without producing a response."""
+
response_type: str = "Multiple Paragraphs"
+ """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+
stream: bool = False
- # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+ """If True, enables streaming output for real-time responses."""
+
top_k: int = int(os.getenv("TOP_K", "60"))
- # Number of document chunks to retrieve.
- # top_n: int = 10
- # Number of tokens for the original chunks.
+ """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+
max_token_for_text_unit: int = 4000
- # Number of tokens for the relationship descriptions
+ """Maximum number of tokens allowed for each retrieved text chunk."""
+
max_token_for_global_context: int = 4000
- # Number of tokens for the entity descriptions
+ """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+
max_token_for_local_context: int = 4000
+ """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+
hl_keywords: list[str] = field(default_factory=list)
+ """List of high-level keywords to prioritize in retrieval."""
+
ll_keywords: list[str] = field(default_factory=list)
- # Conversation history support
- conversation_history: list[dict[str, str]] = field(
- default_factory=list
- ) # Format: [{"role": "user/assistant", "content": "message"}]
- history_turns: int = (
- 3 # Number of complete conversation turns (user-assistant pairs) to consider
- )
+ """List of low-level keywords to refine retrieval focus."""
+
+ conversation_history: list[dict[str, Any]] = field(default_factory=list)
+ """Stores past conversation history to maintain context.
+ Format: [{"role": "user/assistant", "content": "message"}].
+ """
+
+ history_turns: int = 3
+ """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
@dataclass
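
The conversation-history fields documented above use OpenAI-style role/content messages, and `history_turns` bounds how many recent user-assistant pairs are considered. A short sketch of building such a parameter object (the messages themselves are placeholders):

```python
from lightrag import QueryParam

param = QueryParam(
    mode="mix",
    history_turns=1,  # only the most recent user/assistant pair is kept in context
    conversation_history=[
        {"role": "user", "content": "Who is the protagonist?"},
        {"role": "assistant", "content": "The story follows Ebenezer Scrooge."},
    ],
)
# The object is then passed as rag.query(question, param=param).
```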
diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py
index 675cf643..fad03acc 100644
--- a/lightrag/kg/jsondocstatus_impl.py
+++ b/lightrag/kg/jsondocstatus_impl.py
@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
if v["status"] == DocStatus.PENDING
}
+ async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+ """Get all processed documents"""
+ return {
+ k: DocProcessingStatus(**v)
+ for k, v in self._data.items()
+ if v["status"] == DocStatus.PROCESSED
+ }
+
+ async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+ """Get all processing documents"""
+ return {
+ k: DocProcessingStatus(**v)
+ for k, v in self._data.items()
+ if v["status"] == DocStatus.PROCESSING
+ }
+
async def index_done_callback(self):
"""Save data to file after indexing"""
write_json(self._data, self._file_name)
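
These accessors mirror the existing `get_pending_docs`/`get_failed_docs` helpers and let callers inspect documents left in the PROCESSING state, e.g. after an interrupted run. A minimal sketch, assuming `rag` is an initialized `LightRAG` whose status storage is exposed as `rag.doc_status`:

```python
import asyncio


async def report_doc_status(rag):
    # Documents still marked PROCESSING are picked up again on the next pipeline run.
    processing = await rag.doc_status.get_processing_docs()
    processed = await rag.doc_status.get_processed_docs()
    print(f"{len(processing)} still processing, {len(processed)} processed")


# asyncio.run(report_doc_status(rag))
```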
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 024906dd..0d374976 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -109,38 +109,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@dataclass
class LightRAG:
+ """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
+
working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
)
- # Default not to use embedding cache
- embedding_cache_config: dict = field(
+ """Directory where cache and temporary files are stored."""
+
+ embedding_cache_config: dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
"use_llm_check": False,
}
)
+ """Configuration for embedding cache.
+ - enabled: If True, enables caching to avoid redundant computations.
+ - similarity_threshold: Minimum similarity score to use cached embeddings.
+ - use_llm_check: If True, validates cached embeddings using an LLM.
+ """
+
kv_storage: str = field(default="JsonKVStorage")
+ """Storage backend for key-value data."""
+
vector_storage: str = field(default="NanoVectorDBStorage")
+ """Storage backend for vector embeddings."""
+
graph_storage: str = field(default="NetworkXStorage")
+ """Storage backend for knowledge graphs."""
- # logging
+ # Logging
current_log_level = logger.level
- log_level: str = field(default=current_log_level)
+ log_level: int = field(default=current_log_level)
+ """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
+
log_dir: str = field(default=os.getcwd())
+ """Directory where logs are stored. Defaults to the current working directory."""
- # text chunking
+ # Text chunking
chunk_token_size: int = 1200
+ """Maximum number of tokens per text chunk when splitting documents."""
+
chunk_overlap_token_size: int = 100
+ """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
tiktoken_model_name: str = "gpt-4o-mini"
+ """Model name used for tokenization when chunking text."""
- # entity extraction
+ # Entity extraction
entity_extract_max_gleaning: int = 1
- entity_summary_to_max_tokens: int = 500
+ """Maximum number of entity extraction attempts for ambiguous content."""
- # node embedding
+ entity_summary_to_max_tokens: int = 500
+ """Maximum number of tokens used for summarizing extracted entities."""
+
+ # Node embedding
node_embedding_algorithm: str = "node2vec"
- node2vec_params: dict = field(
+ """Algorithm used for node embedding in knowledge graphs."""
+
+ node2vec_params: dict[str, int] = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
@@ -150,26 +177,56 @@ class LightRAG:
"random_seed": 3,
}
)
+ """Configuration for the node2vec embedding algorithm:
+ - dimensions: Number of dimensions for embeddings.
+ - num_walks: Number of random walks per node.
+ - walk_length: Number of steps per random walk.
+ - window_size: Context window size for training.
+ - iterations: Number of iterations for training.
+ - random_seed: Seed value for reproducibility.
+ """
+
+ embedding_func: EmbeddingFunc = None
+ """Function for computing text embeddings. Must be set before use."""
- # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
- embedding_func: EmbeddingFunc = None # This must be set (we do want to separate llm from the corte, so no more default initialization)
embedding_batch_num: int = 32
+ """Batch size for embedding computations."""
+
embedding_func_max_async: int = 16
+ """Maximum number of concurrent embedding function calls."""
+
+ # LLM Configuration
+ llm_model_func: callable = None
+ """Function for interacting with the large language model (LLM). Must be set before use."""
+
+ llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+ """Name of the LLM model used for generating responses."""
- # LLM
- llm_model_func: callable = None # This must be set (we do want to separate llm from the corte, so no more default initialization)
- llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
- llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
- llm_model_kwargs: dict = field(default_factory=dict)
+ """Maximum number of tokens allowed per LLM response."""
+
+ llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+ """Maximum number of concurrent LLM calls."""
+
+ llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
+ """Additional keyword arguments passed to the LLM model function."""
+
+ # Storage
+ vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
+ """Additional parameters for vector database storage."""
- # storage
- vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
namespace_prefix: str = field(default="")
+ """Prefix for namespacing stored data across different environments."""
enable_llm_cache: bool = True
- # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
+ """Enables caching for LLM responses to avoid redundant computations."""
+
enable_llm_cache_for_entity_extract: bool = True
+ """If True, enables caching for entity extraction steps to reduce LLM costs."""
+
+ # Extensions
+ addon_params: dict[str, Any] = field(default_factory=dict)
+ """Dictionary for additional parameters and extensions."""
# extension
addon_params: dict[str, Any] = field(default_factory=dict)
@@ -177,8 +234,8 @@ class LightRAG:
convert_response_to_json
)
- # Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")
+ """Storage type for tracking document processing statuses."""
# Custom Chunking Function
chunking_func: Callable[
@@ -486,7 +543,7 @@ class LightRAG:
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
if not new_docs:
- logger.info("All documents have been processed or are duplicates")
+ logger.info("No new unique documents were found.")
return
# 4. Store status document
@@ -503,15 +560,16 @@ class LightRAG:
each chunk for entity and relation extraction, and updating the
document status.
- 1. Get all pending and failed documents
+ 1. Get all pending and failed documents, plus any left in the PROCESSING state by an abnormal termination.
2. Split document content into chunks
3. Process each chunk for entity and relation extraction
4. Update the document status
"""
- # 1. get all pending and failed documents
+ # 1. Get all pending and failed documents, plus any left in the PROCESSING state by an abnormal termination.
to_process_docs: dict[str, DocProcessingStatus] = {}
- # Fetch failed documents
+ processing_docs = await self.doc_status.get_processing_docs()
+ to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs()
to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs()
@@ -542,6 +600,7 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
+ "content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
@@ -578,6 +637,10 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
+ "content": status_doc.content,
+ "content_summary": status_doc.content_summary,
+ "content_length": status_doc.content_length,
+ "created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
@@ -591,6 +654,10 @@ class LightRAG:
doc_status_id: {
"status": DocStatus.FAILED,
"error": str(e),
+ "content": status_doc.content,
+ "content_summary": status_doc.content_summary,
+ "content_length": status_doc.content_length,
+ "created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
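
Taken together, the newly documented fields form the main configuration surface of `LightRAG`. A sketch of an explicit initialization with stub model functions (the stubs are placeholders; real deployments would plug in the OpenAI helpers shown in the demos):

```python
import os

import numpy as np

from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc


# Placeholder model functions with the documented signatures.
async def my_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return "stub response"


async def my_embedding_func(texts: list[str]) -> np.ndarray:
    return np.random.rand(len(texts), 1536)


os.makedirs("./my_index", exist_ok=True)

rag = LightRAG(
    working_dir="./my_index",
    llm_model_func=my_llm_func,
    chunk_token_size=1200,
    chunk_overlap_token_size=100,
    embedding_cache_config={
        "enabled": True,                 # reuse cached embeddings for similar inputs
        "similarity_threshold": 0.95,
        "use_llm_check": False,
    },
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=my_embedding_func,
    ),
)
```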
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 535d665c..e6d00377 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -103,17 +103,19 @@ async def openai_complete_if_cache(
) -> str:
if history_messages is None:
history_messages = []
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
+ if not api_key:
+ api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
- AsyncOpenAI(default_headers=default_headers)
+ AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
- else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+ else AsyncOpenAI(
+ base_url=base_url, default_headers=default_headers, api_key=api_key
+ )
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
@@ -294,17 +296,19 @@ async def openai_embed(
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
+ if not api_key:
+ api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
- AsyncOpenAI(default_headers=default_headers)
+ AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
- else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+ else AsyncOpenAI(
+ base_url=base_url, default_headers=default_headers, api_key=api_key
+ )
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
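
With this change the OpenAI helpers no longer write the key into `os.environ`; an explicitly passed `api_key` takes precedence, and `OPENAI_API_KEY` is read only as a fallback. A small sketch of the fallback path (the key value is a placeholder):

```python
import asyncio
import os

from lightrag.llm.openai import openai_embed

# Fallback used only when api_key is not passed explicitly.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")


async def main():
    vectors = await openai_embed(texts=["hello world"], model="text-embedding-3-small")
    print(vectors.shape)


asyncio.run(main())
```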
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 811b4194..db7f59a5 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1504,7 +1504,7 @@ async def naive_query(
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
- hashing_kv, args_hash, query, "default", cache_type="query"
+ hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response