From d0779209d937f274478f9bddc8a11367a329495b Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Mon, 10 Feb 2025 14:01:52 +0800
Subject: [PATCH 01/12] Update LOGO

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 62dc032b..850cacd3 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-lightrag
+lightrag
From e1f4f9560da021fb4b731baf119889873728e71d Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 9 Feb 2025 00:13:26 +0100 Subject: [PATCH 02/12] updated documentation --- lightrag/base.py | 71 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 1a7f9c2e..ae5ce92e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -27,31 +27,54 @@ T = TypeVar("T") @dataclass class QueryParam: - mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global" - only_need_context: bool = False - only_need_prompt: bool = False - response_type: str = "Multiple Paragraphs" - stream: bool = False - # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. - top_k: int = int(os.getenv("TOP_K", "60")) - # Number of document chunks to retrieve. - # top_n: int = 10 - # Number of tokens for the original chunks. - max_token_for_text_unit: int = 4000 - # Number of tokens for the relationship descriptions - max_token_for_global_context: int = 4000 - # Number of tokens for the entity descriptions - max_token_for_local_context: int = 4000 - hl_keywords: list[str] = field(default_factory=list) - ll_keywords: list[str] = field(default_factory=list) - # Conversation history support - conversation_history: list[dict[str, str]] = field( - default_factory=list - ) # Format: [{"role": "user/assistant", "content": "message"}] - history_turns: int = ( - 3 # Number of complete conversation turns (user-assistant pairs) to consider - ) + """Configuration parameters for query execution in LightRAG.""" + mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global" + """Specifies the retrieval mode: + - "local": Focuses on context-dependent information. + - "global": Utilizes global knowledge. + - "hybrid": Combines local and global retrieval methods. + - "naive": Performs a basic search without advanced techniques. + - "mix": Integrates knowledge graph and vector retrieval. + """ + + only_need_context: bool = False + """If True, only returns the retrieved context without generating a response.""" + + only_need_prompt: bool = False + """If True, only returns the generated prompt without producing a response.""" + + response_type: str = "Multiple Paragraphs" + """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" + + stream: bool = False + """If True, enables streaming output for real-time responses.""" + + top_k: int = int(os.getenv("TOP_K", "60")) + """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + + max_token_for_text_unit: int = 4000 + """Maximum number of tokens allowed for each retrieved text chunk.""" + + max_token_for_global_context: int = 4000 + """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" + + max_token_for_local_context: int = 4000 + """Maximum number of tokens allocated for entity descriptions in local retrieval.""" + + hl_keywords: List[str] = field(default_factory=list) + """List of high-level keywords to prioritize in retrieval.""" + + ll_keywords: List[str] = field(default_factory=list) + """List of low-level keywords to refine retrieval focus.""" + + conversation_history: List[dict[str, Any]] = field(default_factory=list) + """Stores past conversation history to maintain context. + Format: [{"role": "user/assistant", "content": "message"}]. 
+ """ + + history_turns: int = 3 + """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" @dataclass class StorageNameSpace: From 4c2f13f79e27ce4790cce637a25825989a4893ef Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 9 Feb 2025 00:23:55 +0100 Subject: [PATCH 03/12] improved docs --- lightrag/lightrag.py | 97 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 20 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 347f0f4c..eff4614d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -109,38 +109,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: @dataclass class LightRAG: + """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" + working_dir: str = field( default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' ) - # Default not to use embedding cache - embedding_cache_config: dict = field( + """Directory where cache and temporary files are stored.""" + + embedding_cache_config: dict[str, Any] = field( default_factory=lambda: { "enabled": False, "similarity_threshold": 0.95, "use_llm_check": False, } ) + """Configuration for embedding cache. + - enabled: If True, enables caching to avoid redundant computations. + - similarity_threshold: Minimum similarity score to use cached embeddings. + - use_llm_check: If True, validates cached embeddings using an LLM. + """ + kv_storage: str = field(default="JsonKVStorage") + """Storage backend for key-value data.""" + vector_storage: str = field(default="NanoVectorDBStorage") + """Storage backend for vector embeddings.""" + graph_storage: str = field(default="NetworkXStorage") + """Storage backend for knowledge graphs.""" - # logging + # Logging current_log_level = logger.level - log_level: str = field(default=current_log_level) + log_level: int = field(default=current_log_level) + """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" + log_dir: str = field(default=os.getcwd()) + """Directory where logs are stored. Defaults to the current working directory.""" - # text chunking + # Text chunking chunk_token_size: int = 1200 + """Maximum number of tokens per text chunk when splitting documents.""" + chunk_overlap_token_size: int = 100 + """Number of overlapping tokens between consecutive text chunks to preserve context.""" + tiktoken_model_name: str = "gpt-4o-mini" + """Model name used for tokenization when chunking text.""" - # entity extraction + # Entity extraction entity_extract_max_gleaning: int = 1 - entity_summary_to_max_tokens: int = 500 + """Maximum number of entity extraction attempts for ambiguous content.""" - # node embedding + entity_summary_to_max_tokens: int = 500 + """Maximum number of tokens used for summarizing extracted entities.""" + + # Node embedding node_embedding_algorithm: str = "node2vec" - node2vec_params: dict = field( + """Algorithm used for node embedding in knowledge graphs.""" + + node2vec_params: dict[str, int] = field( default_factory=lambda: { "dimensions": 1536, "num_walks": 10, @@ -150,26 +177,56 @@ class LightRAG: "random_seed": 3, } ) + """Configuration for the node2vec embedding algorithm: + - dimensions: Number of dimensions for embeddings. + - num_walks: Number of random walks per node. + - walk_length: Number of steps per random walk. + - window_size: Context window size for training. + - iterations: Number of iterations for training. + - random_seed: Seed value for reproducibility. 
+ """ + + embedding_func: EmbeddingFunc = None + """Function for computing text embeddings. Must be set before use.""" - # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) - embedding_func: EmbeddingFunc = None # This must be set (we do want to separate llm from the corte, so no more default initialization) embedding_batch_num: int = 32 + """Batch size for embedding computations.""" + embedding_func_max_async: int = 16 + """Maximum number of concurrent embedding function calls.""" + + # LLM Configuration + llm_model_func: callable = None + """Function for interacting with the large language model (LLM). Must be set before use.""" + + llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" + """Name of the LLM model used for generating responses.""" - # LLM - llm_model_func: callable = None # This must be set (we do want to separate llm from the corte, so no more default initialization) - llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768")) - llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16")) - llm_model_kwargs: dict = field(default_factory=dict) + """Maximum number of tokens allowed per LLM response.""" + + llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16")) + """Maximum number of concurrent LLM calls.""" + + llm_model_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments passed to the LLM model function.""" + + # Storage + vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional parameters for vector database storage.""" - # storage - vector_db_storage_cls_kwargs: dict = field(default_factory=dict) namespace_prefix: str = field(default="") + """Prefix for namespacing stored data across different environments.""" enable_llm_cache: bool = True - # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag + """Enables caching for LLM responses to avoid redundant computations.""" + enable_llm_cache_for_entity_extract: bool = True + """If True, enables caching for entity extraction steps to reduce LLM costs.""" + + # Extensions + addon_params: dict[str, Any] = field(default_factory=dict) + """Dictionary for additional parameters and extensions.""" # extension addon_params: dict[str, Any] = field(default_factory=dict) @@ -177,8 +234,8 @@ class LightRAG: convert_response_to_json ) - # Add new field for document status storage type doc_status_storage: str = field(default="JsonDocStatusStorage") + """Storage type for tracking document processing statuses.""" # Custom Chunking Function chunking_func: Callable[ From 41f76ec4592eca9873e9f94ea9a5545c51f266e8 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 9 Feb 2025 01:05:27 +0100 Subject: [PATCH 04/12] updated readme --- README.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 850cacd3..480b8d00 100644 --- a/README.md +++ b/README.md @@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and ```python class QueryParam: mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global" + """Specifies the retrieval mode: + - "local": Focuses on context-dependent information. + - "global": Utilizes global knowledge. + - "hybrid": Combines local and global retrieval methods. 
+ - "naive": Performs a basic search without advanced techniques. + - "mix": Integrates knowledge graph and vector retrieval. + """ only_need_context: bool = False + """If True, only returns the retrieved context without generating a response.""" response_type: str = "Multiple Paragraphs" - # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. + """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" top_k: int = 60 - # Number of tokens for the original chunks. + """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" max_token_for_text_unit: int = 4000 - # Number of tokens for the relationship descriptions + """Maximum number of tokens allowed for each retrieved text chunk.""" max_token_for_global_context: int = 4000 - # Number of tokens for the entity descriptions + """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" max_token_for_local_context: int = 4000 + """Maximum number of tokens allocated for entity descriptions in local retrieval.""" + ... ``` > default value of Top_k can be change by environment variables TOP_K. From 23283180c7a6e368489df53bb9fa4f209cb00154 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 9 Feb 2025 18:03:34 +0100 Subject: [PATCH 05/12] fixed type --- README.md | 2 +- lightrag/base.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 480b8d00..cf1d86aa 100644 --- a/README.md +++ b/README.md @@ -361,7 +361,7 @@ class QueryParam: - "hybrid": Combines local and global retrieval methods. - "naive": Performs a basic search without advanced techniques. - "mix": Integrates knowledge graph and vector retrieval. - """ + """ only_need_context: bool = False """If True, only returns the retrieved context without generating a response.""" response_type: str = "Multiple Paragraphs" diff --git a/lightrag/base.py b/lightrag/base.py index ae5ce92e..0e3f1dc6 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -62,13 +62,13 @@ class QueryParam: max_token_for_local_context: int = 4000 """Maximum number of tokens allocated for entity descriptions in local retrieval.""" - hl_keywords: List[str] = field(default_factory=list) + hl_keywords: list[str] = field(default_factory=list) """List of high-level keywords to prioritize in retrieval.""" - ll_keywords: List[str] = field(default_factory=list) + ll_keywords: list[str] = field(default_factory=list) """List of low-level keywords to refine retrieval focus.""" - conversation_history: List[dict[str, Any]] = field(default_factory=list) + conversation_history: list[dict[str, Any]] = field(default_factory=list) """Stores past conversation history to maintain context. Format: [{"role": "user/assistant", "content": "message"}]. 
""" @@ -76,6 +76,7 @@ class QueryParam: history_turns: int = 3 """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" + @dataclass class StorageNameSpace: namespace: str From 0c3b7541081f98b254699c5b7da313663119b6d0 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 11:42:46 +0800 Subject: [PATCH 06/12] Fix bugs --- examples/lightrag_openai_demo.py | 3 ++- lightrag/operate.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index 7a43a710..c5393fc8 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -1,7 +1,7 @@ import os from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed WORKING_DIR = "./dickens" @@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR): rag = LightRAG( working_dir=WORKING_DIR, + embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, # llm_model_func=gpt_4o_complete ) diff --git a/lightrag/operate.py b/lightrag/operate.py index 811b4194..db7f59a5 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1504,7 +1504,7 @@ async def naive_query( use_model_func = global_config["llm_model_func"] args_hash = compute_args_hash(query_param.mode, query, cache_type="query") cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, query, "default", cache_type="query" + hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) if cached_response is not None: return cached_response From 24e0f0390e0c558ff7369bd8dc4081f584b79d42 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 11 Feb 2025 11:59:28 +0800 Subject: [PATCH 07/12] Update __version__ --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index d68bded0..031502d6 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.5" +__version__ = "1.1.6" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 2d2ed19095170186eaa65324310f03354d3dccba Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 13:28:18 +0800 Subject: [PATCH 08/12] Fix cache bugs --- .../lightrag_api_openai_compatible_demo.py | 14 +++++++++++--- lightrag/kg/jsondocstatus_impl.py | 16 ++++++++++++++++ lightrag/lightrag.py | 18 ++++++++++++++---- lightrag/llm/openai.py | 16 ++++++++-------- 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index 8173dc5b..f55b9ce1 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") +BASE_URL = int(os.environ.get("BASE_URL", "https://api.openai.com/v1")) +print(f"BASE_URL: {BASE_URL}") +API_KEY = int(os.environ.get("API_KEY", "xxxxxxxx")) +print(f"API_KEY: {API_KEY}") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -36,10 +40,12 @@ async def llm_model_func( prompt, 
system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs ) -> str: return await openai_complete_if_cache( - LLM_MODEL, - prompt, + model=LLM_MODEL, + prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, + base_url=BASE_URL, + api_key=API_KEY, **kwargs, ) @@ -49,8 +55,10 @@ async def llm_model_func( async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embed( - texts, + texts=texts, model=EMBEDDING_MODEL, + base_url=BASE_URL, + api_key=API_KEY, ) diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index 675cf643..fe67b830 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage): if v["status"] == DocStatus.PENDING } + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processed documents""" + return { + k: DocProcessingStatus(**v) + for k, v in self._data.items() + if v["status"] == DocStatus.PROCESSED + } + + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + return { + k: DocProcessingStatus(**v) + for k, v in self._data.items() + if v["status"] == DocStatus.PROCESSING + } + async def index_done_callback(self): """Save data to file after indexing""" write_json(self._data, self._file_name) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index eff4614d..5ec0dbeb 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -543,7 +543,7 @@ class LightRAG: new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids} if not new_docs: - logger.info("All documents have been processed or are duplicates") + logger.info("No new unique documents were found.") return # 4. Store status document @@ -560,15 +560,16 @@ class LightRAG: each chunk for entity and relation extraction, and updating the document status. - 1. Get all pending and failed documents + 1. Get all pending, failed, and abnormally terminated processing documents. 2. Split document content into chunks 3. Process each chunk for entity and relation extraction 4. Update the document status """ - # 1. get all pending and failed documents + # 1. Get all pending, failed, and abnormally terminated processing documents. 
to_process_docs: dict[str, DocProcessingStatus] = {} - # Fetch failed documents + processing_docs = await self.doc_status.get_processing_docs() + to_process_docs.update(processing_docs) failed_docs = await self.doc_status.get_failed_docs() to_process_docs.update(failed_docs) pendings_docs = await self.doc_status.get_pending_docs() @@ -599,6 +600,7 @@ class LightRAG: doc_status_id: { "status": DocStatus.PROCESSING, "updated_at": datetime.now().isoformat(), + "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, "created_at": status_doc.created_at, @@ -635,6 +637,10 @@ class LightRAG: doc_status_id: { "status": DocStatus.PROCESSED, "chunks_count": len(chunks), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, "updated_at": datetime.now().isoformat(), } } @@ -648,6 +654,10 @@ class LightRAG: doc_status_id: { "status": DocStatus.FAILED, "error": str(e), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, "updated_at": datetime.now().isoformat(), } } diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 535d665c..2a2ba0fd 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -103,17 +103,17 @@ async def openai_complete_if_cache( ) -> str: if history_messages is None: history_messages = [] - if api_key: - os.environ["OPENAI_API_KEY"] = api_key + if not api_key: + api_key = os.environ["OPENAI_API_KEY"] default_headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) + AsyncOpenAI(default_headers=default_headers, api_key=api_key) if base_url is None - else AsyncOpenAI(base_url=base_url, default_headers=default_headers) + else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) ) kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) @@ -294,17 +294,17 @@ async def openai_embed( base_url: str = None, api_key: str = None, ) -> np.ndarray: - if api_key: - os.environ["OPENAI_API_KEY"] = api_key + if not api_key: + api_key = os.environ["OPENAI_API_KEY"] default_headers = { "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) + AsyncOpenAI(default_headers=default_headers, api_key=api_key) if base_url is None - else AsyncOpenAI(base_url=base_url, default_headers=default_headers) + else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" From 5ffbb548ad4287f60479ffc762189177f71f1c71 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 13:32:24 +0800 Subject: [PATCH 09/12] Fix linting error --- lightrag/kg/jsondocstatus_impl.py | 2 +- lightrag/llm/openai.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index fe67b830..fad03acc 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -116,7 +116,7 @@ class JsonDocStatusStorage(DocStatusStorage): for k, v in 
self._data.items() if v["status"] == DocStatus.PROCESSED } - + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: """Get all processing documents""" return { diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 2a2ba0fd..e6d00377 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -113,7 +113,9 @@ async def openai_complete_if_cache( openai_async_client = ( AsyncOpenAI(default_headers=default_headers, api_key=api_key) if base_url is None - else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) + else AsyncOpenAI( + base_url=base_url, default_headers=default_headers, api_key=api_key + ) ) kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) @@ -304,7 +306,9 @@ async def openai_embed( openai_async_client = ( AsyncOpenAI(default_headers=default_headers, api_key=api_key) if base_url is None - else AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) + else AsyncOpenAI( + base_url=base_url, default_headers=default_headers, api_key=api_key + ) ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" From f44c83594eca199582fa326a9bdc5f86736923a4 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 13:40:05 +0800 Subject: [PATCH 10/12] add lightrag_api_openai_compatible_demo_simplified.py --- ...g_api_openai_compatible_demo_simplified.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 examples/lightrag_api_openai_compatible_demo_simplified.py diff --git a/examples/lightrag_api_openai_compatible_demo_simplified.py b/examples/lightrag_api_openai_compatible_demo_simplified.py new file mode 100644 index 00000000..ed36fbc6 --- /dev/null +++ b/examples/lightrag_api_openai_compatible_demo_simplified.py @@ -0,0 +1,102 @@ +import os +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc +import numpy as np +import asyncio +import nest_asyncio + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +DEFAULT_RAG_DIR = "index_default" + +# Configure working directory +WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") +print(f"WORKING_DIR: {WORKING_DIR}") +LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini") +print(f"LLM_MODEL: {LLM_MODEL}") +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small") +print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") +EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) +print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") +BASE_URL = int(os.environ.get("BASE_URL", "https://api.openai.com/v1")) +print(f"BASE_URL: {BASE_URL}") +API_KEY = int(os.environ.get("API_KEY", "xxxxxxxx")) +print(f"API_KEY: {API_KEY}") + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +# LLM model function + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await openai_complete_if_cache( + model=LLM_MODEL, + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + base_url=BASE_URL, + api_key=API_KEY, + **kwargs, + ) + + +# Embedding function + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts=texts, + model=EMBEDDING_MODEL, + base_url=BASE_URL, + api_key=API_KEY, + ) + + +async def get_embedding_dim(): + test_text = ["This is a test 
sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + print(f"{embedding_dim=}") + return embedding_dim + + +# Initialize RAG instance +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=asyncio.run(get_embedding_dim()), + max_token_size=EMBEDDING_MAX_TOKEN_SIZE, + func=embedding_func, + ), +) + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) + From 43ed7386d18d92be415ec766182cb092f1d4dd51 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 14:21:19 +0800 Subject: [PATCH 11/12] fix linting error --- examples/lightrag_api_openai_compatible_demo_simplified.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/lightrag_api_openai_compatible_demo_simplified.py b/examples/lightrag_api_openai_compatible_demo_simplified.py index ed36fbc6..9f791632 100644 --- a/examples/lightrag_api_openai_compatible_demo_simplified.py +++ b/examples/lightrag_api_openai_compatible_demo_simplified.py @@ -99,4 +99,3 @@ print( print( rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) ) - From 80a3ce2240be9d4711d8fa3c57f3006217730193 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 11 Feb 2025 16:24:22 +0800 Subject: [PATCH 12/12] fix bugs --- examples/lightrag_api_openai_compatible_demo.py | 4 ++-- examples/lightrag_api_openai_compatible_demo_simplified.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index f55b9ce1..e2d63e41 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -24,9 +24,9 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") -BASE_URL = int(os.environ.get("BASE_URL", "https://api.openai.com/v1")) +BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1") print(f"BASE_URL: {BASE_URL}") -API_KEY = int(os.environ.get("API_KEY", "xxxxxxxx")) +API_KEY = os.environ.get("API_KEY", "xxxxxxxx") print(f"API_KEY: {API_KEY}") if not os.path.exists(WORKING_DIR): diff --git a/examples/lightrag_api_openai_compatible_demo_simplified.py b/examples/lightrag_api_openai_compatible_demo_simplified.py index 9f791632..fabbb3e2 100644 --- a/examples/lightrag_api_openai_compatible_demo_simplified.py +++ b/examples/lightrag_api_openai_compatible_demo_simplified.py @@ -20,9 +20,9 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") -BASE_URL = int(os.environ.get("BASE_URL", "https://api.openai.com/v1")) +BASE_URL = 
os.environ.get("BASE_URL", "https://api.openai.com/v1") print(f"BASE_URL: {BASE_URL}") -API_KEY = int(os.environ.get("API_KEY", "xxxxxxxx")) +API_KEY = os.environ.get("API_KEY", "xxxxxxxx") print(f"API_KEY: {API_KEY}") if not os.path.exists(WORKING_DIR):