diff --git a/README.md b/README.md
index 850cacd3..62f21a65 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
 ```python
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
 
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -95,12 +95,12 @@ from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
 
 WORKING_DIR = "./dickens"
 
-
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
+    embedding_func=openai_embed,
     llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
@@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and
 ```python
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
     response_type: str = "Multiple Paragraphs"
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
     top_k: int = 60
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+
     ...
 ```
 > The default value of top_k can be changed via the environment variable TOP_K.
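For reference, the `QueryParam` fields documented in the README excerpt above map directly onto `rag.query()` calls. A minimal sketch (reusing the `rag` instance from the snippet above with documents already inserted; the `top_k` override is purely illustrative):

```python
from lightrag import QueryParam

# Hybrid retrieval with a smaller candidate pool than the TOP_K default of 60
print(
    rag.query(
        "What are the top themes in this story?",
        param=QueryParam(mode="hybrid", top_k=20),
    )
)

# Return only the retrieved context, skipping answer generation
print(
    rag.query(
        "What are the top themes in this story?",
        param=QueryParam(mode="local", only_need_context=True),
    )
)
```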
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index 8173dc5b..e2d63e41 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        LLM_MODEL,
-        prompt,
+        model=LLM_MODEL,
+        prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
         **kwargs,
     )
 
@@ -49,8 +55,10 @@ async def llm_model_func(
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed(
-        texts,
+        texts=texts,
         model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
     )
 
 
diff --git a/examples/lightrag_api_openai_compatible_demo_simplified.py b/examples/lightrag_api_openai_compatible_demo_simplified.py
new file mode 100644
index 00000000..fabbb3e2
--- /dev/null
+++ b/examples/lightrag_api_openai_compatible_demo_simplified.py
@@ -0,0 +1,101 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
+print(f"WORKING_DIR: {WORKING_DIR}")
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+# LLM model function
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        model=LLM_MODEL,
+        prompt=prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
+        **kwargs,
+    )
+
+
+# Embedding function
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await openai_embed(
+        texts=texts,
+        model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
+    )
+
+
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"{embedding_dim=}")
+    return embedding_dim
+
+
+# Initialize RAG instance
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=asyncio.run(get_embedding_dim()),
+        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+        func=embedding_func,
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index 7a43a710..c5393fc8 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -1,7 +1,7 @@
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
 
 WORKING_DIR = "./dickens"
 
@@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
+    embedding_func=openai_embed,
     llm_model_func=gpt_4o_mini_complete,
     # llm_model_func=gpt_4o_complete
 )
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index d68bded0..031502d6 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.1.5"
+__version__ = "1.1.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/base.py b/lightrag/base.py
index 1a7f9c2e..bd79d990 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -27,30 +27,54 @@ T = TypeVar("T")
 
 @dataclass
 class QueryParam:
+    """Configuration parameters for query execution in LightRAG."""
+
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
+
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
+
     only_need_prompt: bool = False
+    """If True, only returns the generated prompt without producing a response."""
+
     response_type: str = "Multiple Paragraphs"
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+
     stream: bool = False
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """If True, enables streaming output for real-time responses."""
+
     top_k: int = int(os.getenv("TOP_K", "60"))
-    # Number of document chunks to retrieve.
-    # top_n: int = 10
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+
     hl_keywords: list[str] = field(default_factory=list)
+    """List of high-level keywords to prioritize in retrieval."""
+
     ll_keywords: list[str] = field(default_factory=list)
-    # Conversation history support
-    conversation_history: list[dict[str, str]] = field(
-        default_factory=list
-    )  # Format: [{"role": "user/assistant", "content": "message"}]
-    history_turns: int = (
-        3  # Number of complete conversation turns (user-assistant pairs) to consider
-    )
+    """List of low-level keywords to refine retrieval focus."""
+
+    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+
+    history_turns: int = 3
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
 
 
 @dataclass
@@ -202,3 +226,7 @@ class DocStatusStorage(BaseKVStorage):
     async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
+
+    async def update_doc_status(self, data: dict[str, Any]) -> None:
+        """Updates the status of a document. By default, it calls upsert."""
+        await self.upsert(data)
diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py
index 675cf643..fad03acc 100644
--- a/lightrag/kg/jsondocstatus_impl.py
+++ b/lightrag/kg/jsondocstatus_impl.py
@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
             if v["status"] == DocStatus.PENDING
         }
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSED
+        }
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSING
+        }
+
     async def index_done_callback(self):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 5bd0a949..526f54a7 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -530,6 +530,32 @@ class PGDocStatusStorage(DocStatusStorage):
         )
         return data
 
+    async def update_doc_status(self, data: dict[str, dict]) -> None:
+        """
+        Updates only the document status, chunk count, and updated timestamp.
+
+        This method ensures that only relevant fields are updated instead of overwriting
+        the entire document record. If `updated_at` is not provided, the database will
+        automatically use the current timestamp.
+ """ + sql = """ + UPDATE LIGHTRAG_DOC_STATUS + SET status = $3, + chunks_count = $4, + updated_at = CURRENT_TIMESTAMP + WHERE workspace = $1 AND id = $2 + """ + for k, v in data.items(): + _data = { + "workspace": self.db.workspace, + "id": k, + "status": v["status"].value, # Convert Enum to string + "chunks_count": v.get( + "chunks_count", -1 + ), # Default to -1 if not provided + } + await self.db.execute(sql, _data) + class PGGraphQueryException(Exception): """Exception for the AGE queries.""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b1670850..cb1aa195 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -211,38 +211,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: @dataclass class LightRAG: + """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" + working_dir: str = field( default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' ) - # Default not to use embedding cache - embedding_cache_config: dict = field( + """Directory where cache and temporary files are stored.""" + + embedding_cache_config: dict[str, Any] = field( default_factory=lambda: { "enabled": False, "similarity_threshold": 0.95, "use_llm_check": False, } ) + """Configuration for embedding cache. + - enabled: If True, enables caching to avoid redundant computations. + - similarity_threshold: Minimum similarity score to use cached embeddings. + - use_llm_check: If True, validates cached embeddings using an LLM. + """ + kv_storage: str = field(default="JsonKVStorage") + """Storage backend for key-value data.""" + vector_storage: str = field(default="NanoVectorDBStorage") + """Storage backend for vector embeddings.""" + graph_storage: str = field(default="NetworkXStorage") + """Storage backend for knowledge graphs.""" - # logging + # Logging current_log_level = logger.level - log_level: str = field(default=current_log_level) + log_level: int = field(default=current_log_level) + """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" + log_dir: str = field(default=os.getcwd()) + """Directory where logs are stored. Defaults to the current working directory.""" - # text chunking + # Text chunking chunk_token_size: int = 1200 + """Maximum number of tokens per text chunk when splitting documents.""" + chunk_overlap_token_size: int = 100 + """Number of overlapping tokens between consecutive text chunks to preserve context.""" + tiktoken_model_name: str = "gpt-4o-mini" + """Model name used for tokenization when chunking text.""" - # entity extraction + # Entity extraction entity_extract_max_gleaning: int = 1 - entity_summary_to_max_tokens: int = 500 + """Maximum number of entity extraction attempts for ambiguous content.""" - # node embedding + entity_summary_to_max_tokens: int = 500 + """Maximum number of tokens used for summarizing extracted entities.""" + + # Node embedding node_embedding_algorithm: str = "node2vec" - node2vec_params: dict = field( + """Algorithm used for node embedding in knowledge graphs.""" + + node2vec_params: dict[str, int] = field( default_factory=lambda: { "dimensions": 1536, "num_walks": 10, @@ -252,26 +279,56 @@ class LightRAG: "random_seed": 3, } ) + """Configuration for the node2vec embedding algorithm: + - dimensions: Number of dimensions for embeddings. + - num_walks: Number of random walks per node. + - walk_length: Number of steps per random walk. + - window_size: Context window size for training. + - iterations: Number of iterations for training. 
+    - random_seed: Seed value for reproducibility.
+    """
+
+    embedding_func: EmbeddingFunc = None
+    """Function for computing text embeddings. Must be set before use."""
 
-    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = None  # This must be set (we do want to separate llm from the corte, so no more default initialization)
     embedding_batch_num: int = 32
+    """Batch size for embedding computations."""
+
     embedding_func_max_async: int = 16
+    """Maximum number of concurrent embedding function calls."""
+
+    # LLM Configuration
+    llm_model_func: callable = None
+    """Function for interacting with the large language model (LLM). Must be set before use."""
+
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    """Name of the LLM model used for generating responses."""
 
-    # LLM
-    llm_model_func: callable = None  # This must be set (we do want to separate llm from the corte, so no more default initialization)
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
-    llm_model_kwargs: dict = field(default_factory=dict)
+    """Maximum number of tokens allowed per LLM response."""
+
+    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    """Maximum number of concurrent LLM calls."""
+
+    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
+    """Additional keyword arguments passed to the LLM model function."""
+
+    # Storage
+    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
+    """Additional parameters for vector database storage."""
 
-    # storage
-    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
     namespace_prefix: str = field(default="")
+    """Prefix for namespacing stored data across different environments."""
 
     enable_llm_cache: bool = True
-    # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
+    """Enables caching for LLM responses to avoid redundant computations."""
+
     enable_llm_cache_for_entity_extract: bool = True
+    """If True, enables caching for entity extraction steps to reduce LLM costs."""
+
+    # Extensions
+    addon_params: dict[str, Any] = field(default_factory=dict)
+    """Dictionary for additional parameters and extensions."""
 
     # extension
     addon_params: dict[str, Any] = field(default_factory=dict)
@@ -279,8 +336,8 @@ class LightRAG:
         convert_response_to_json
     )
 
-    # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
+    """Storage type for tracking document processing statuses."""
 
     # Custom Chunking Function
     chunking_func: Callable[
@@ -799,7 +856,7 @@ class LightRAG:
             new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
 
         if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
+            logger.info("No new unique documents were found.")
             return
 
         # 4. Store status document
@@ -816,15 +873,16 @@ class LightRAG:
         each chunk for entity and relation extraction, and updating the document status.
 
-        1. Get all pending and failed documents
+        1. Get all pending, failed, and abnormally terminated processing documents.
         2. Split document content into chunks
         3. Process each chunk for entity and relation extraction
        4. Update the document status
         """
-        # 1. get all pending and failed documents
+        # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}
 
-        # Fetch failed documents
+        processing_docs = await self.doc_status.get_processing_docs()
+        to_process_docs.update(processing_docs)
         failed_docs = await self.doc_status.get_failed_docs()
         to_process_docs.update(failed_docs)
 
         pendings_docs = await self.doc_status.get_pending_docs()
@@ -855,6 +913,7 @@ class LightRAG:
                     doc_status_id: {
                         "status": DocStatus.PROCESSING,
                        "updated_at": datetime.now().isoformat(),
+                        "content": status_doc.content,
                        "content_summary": status_doc.content_summary,
                        "content_length": status_doc.content_length,
                        "created_at": status_doc.created_at,
@@ -886,11 +945,15 @@ class LightRAG:
                 ]
                 try:
                     await asyncio.gather(*tasks)
-                    await self.doc_status.upsert(
+                    await self.doc_status.update_doc_status(
                         {
                             doc_status_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
                                 "updated_at": datetime.now().isoformat(),
                             }
                         }
@@ -899,11 +962,15 @@ class LightRAG:
 
                 except Exception as e:
                     logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                    await self.doc_status.upsert(
+                    await self.doc_status.update_doc_status(
                         {
                             doc_status_id: {
                                 "status": DocStatus.FAILED,
                                 "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
                                 "updated_at": datetime.now().isoformat(),
                             }
                         }
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 535d665c..e6d00377 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -103,17 +103,19 @@ async def openai_complete_if_cache(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
@@ -294,17 +296,19 @@ async def openai_embed(
     base_url: str = None,
     api_key: str = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 811b4194..db7f59a5 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1504,7 +1504,7 @@ async def naive_query(
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, query, "default", cache_type="query"
+        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
     if cached_response is not None:
         return cached_response