From 4bbe4f8230c9afeb560a045db0314d947a3c87b2 Mon Sep 17 00:00:00 2001 From: zhangjiawei Date: Wed, 16 Oct 2024 18:10:28 +0800 Subject: [PATCH 01/35] setup encoding modify --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 849fabfe..47222420 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ import setuptools -with open("README.md", "r") as fh: +with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() From d517ef9c209b96dc61ff7f3fb860a6f7e2b6d714 Mon Sep 17 00:00:00 2001 From: Soumil Date: Mon, 21 Oct 2024 18:34:43 +0100 Subject: [PATCH 02/35] added a class to use multiple models --- lightrag/llm.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/lightrag/llm.py b/lightrag/llm.py index be801e0c..d820766d 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -13,6 +13,8 @@ from tenacity import ( ) from transformers import AutoTokenizer, AutoModelForCausalLM import torch +from pydantic import BaseModel, Field +from typing import List, Dict, Callable, Any from .base import BaseKVStorage from .utils import compute_args_hash, wrap_embedding_func_with_attrs @@ -423,6 +425,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: return embed_text +class Model(BaseModel): + """ + This is a Pydantic model class named 'Model' that is used to define a custom language model. + + Attributes: + gen_func (Callable[[Any], str]): A callable function that generates the response from the language model. + The function should take any argument and return a string. + kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. + This could include parameters such as the model name, API key, etc. + + Example usage: + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}) + + In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model. + The 'kwargs' dictionary contains the model name and API key to be passed to the function. + """ + + gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string") + kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc") + + class Config: + arbitrary_types_allowed = True + + +class MultiModel(): + """ + Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. + Could also be used for spliting across diffrent models or providers. + + Attributes: + models (List[Model]): A list of language models to be used. 
+ + Usage example: + ```python + models = [ + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}), + ] + multi_model = MultiModel(models) + rag = LightRAG( + llm_model_func=multi_model.llm_model_func + / ..other args + ) + ``` + """ + def __init__(self, models: List[Model]): + self._models = models + self._current_model = 0 + + def _next_model(self): + self._current_model = (self._current_model + 1) % len(self._models) + return self._models[self._current_model] + + async def llm_model_func( + self, + prompt, system_prompt=None, history_messages=[], **kwargs + ) -> str: + kwargs.pop("model", None) # stop from overwriting the custom model name + next_model = self._next_model() + args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs) + + return await next_model.gen_func( + **args + ) if __name__ == "__main__": import asyncio From c69a3606c6c7b48a5adcdfd8e6c5c8e8a353c63e Mon Sep 17 00:00:00 2001 From: Abyl Ikhsanov Date: Mon, 21 Oct 2024 20:40:49 +0200 Subject: [PATCH 03/35] Update llm.py --- lightrag/llm.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index be801e0c..51c48b84 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,7 +4,7 @@ import json import aioboto3 import numpy as np import ollama -from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout +from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI from tenacity import ( retry, stop_after_attempt, @@ -61,6 +61,49 @@ async def openai_complete_if_cache( ) return response.choices[0].message.content +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def azure_openai_complete_if_cache(model, + prompt, + system_prompt=None, + history_messages=[], + base_url=None, + api_key=None, + **kwargs): + if api_key: + os.environ["AZURE_OPENAI_API_KEY"] = api_key + if base_url: + os.environ["AZURE_OPENAI_ENDPOINT"] = base_url + + openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + if prompt is not None: + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": 
response.choices[0].message.content, "model": model}} + ) + return response.choices[0].message.content class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" @@ -270,6 +313,16 @@ async def gpt_4o_mini_complete( **kwargs, ) +async def azure_openai_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await azure_openai_complete_if_cache( + "conversation-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs @@ -332,6 +385,32 @@ async def openai_embedding( ) return np.array([dp.embedding for dp in response.data]) +@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def azure_openai_embedding( + texts: list[str], + model: str = "text-embedding-3-small", + base_url: str = None, + api_key: str = None, +) -> np.ndarray: + if api_key: + os.environ["AZURE_OPENAI_API_KEY"] = api_key + if base_url: + os.environ["AZURE_OPENAI_ENDPOINT"] = base_url + + openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + + response = await openai_async_client.embeddings.create( + model=model, input=texts, encoding_format="float" + ) + return np.array([dp.embedding for dp in response.data]) + # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) # @retry( From 274d0fcc92f0f77d30d34da5d9fbb4a0b9a11fd0 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 22 Oct 2024 15:16:57 +0800 Subject: [PATCH 04/35] feat(examples): support siliconcloud free API --- README.md | 1 + examples/lightrag_siliconcloud_demo.py | 79 ++++++++++++++++++++++++++ lightrag/llm.py | 48 +++++++++++++++- requirements.txt | 3 +- 4 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 examples/lightrag_siliconcloud_demo.py diff --git a/README.md b/README.md index 76535d19..87335f1f 100644 --- a/README.md +++ b/README.md @@ -629,6 +629,7 @@ def extract_queries(file_path): │ ├── lightrag_ollama_demo.py │ ├── lightrag_openai_compatible_demo.py │ ├── lightrag_openai_demo.py +│ ├── lightrag_siliconcloud_demo.py │ └── vram_management_demo.py ├── lightrag │ ├── __init__.py diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py new file mode 100644 index 00000000..e3f5e67e --- /dev/null +++ b/examples/lightrag_siliconcloud_demo.py @@ -0,0 +1,79 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "Qwen/Qwen2.5-7B-Instruct", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.siliconflow.cn/v1/", + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await siliconcloud_embedding( + texts, + model="netease-youdao/bce-embedding-base_v1", + 
api_key=os.getenv("UPSTAGE_API_KEY"), + max_token_size=int(512 * 1.5) + ) + + +# function test +async def test_funcs(): + result = await llm_model_func("How are you?") + print("llm_model_func: ", result) + + result = await embedding_func(["How are you?"]) + print("embedding_func: ", result) + + +asyncio.run(test_funcs()) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=768, max_token_size=512, func=embedding_func + ), +) + + +with open("./book.txt") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/lightrag/llm.py b/lightrag/llm.py index be801e0c..06d75d01 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -2,8 +2,11 @@ import os import copy import json import aioboto3 +import aiohttp import numpy as np import ollama +import base64 +import struct from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout from tenacity import ( retry, @@ -312,7 +315,7 @@ async def ollama_model_complete( @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), + wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) async def openai_embedding( @@ -332,6 +335,49 @@ async def openai_embedding( ) return np.array([dp.embedding for dp in response.data]) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def siliconcloud_embedding( + texts: list[str], + model: str = "netease-youdao/bce-embedding-base_v1", + base_url: str = "https://api.siliconflow.cn/v1/embeddings", + max_token_size: int = 512, + api_key: str = None, +) -> np.ndarray: + if api_key and not api_key.startswith('Bearer '): + api_key = 'Bearer ' + api_key + + headers = { + "Authorization": api_key, + "Content-Type": "application/json" + } + + truncate_texts = [text[0:max_token_size] for text in texts] + + payload = { + "model": model, + "input": truncate_texts, + "encoding_format": "base64" + } + + base64_strings = [] + async with aiohttp.ClientSession() as session: + async with session.post(base_url, headers=headers, json=payload) as response: + content = await response.json() + if 'code' in content: + raise ValueError(content) + base64_strings = [item['embedding'] for item in content['data']] + + embeddings = [] + for string in base64_strings: + decode_bytes = base64.b64decode(string) + n = len(decode_bytes) // 4 + float_array = struct.unpack('<' + 'f' * n, decode_bytes) + embeddings.append(float_array) + return np.array(embeddings) # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) # @retry( diff --git a/requirements.txt b/requirements.txt index 9cc5b7e9..5b3396fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ tiktoken torch transformers xxhash -pyvis \ No newline at end of file +pyvis +aiohttp \ No newline at end of 
file From 64124005939dceb7b2c2e52bb4f75112fba1a7ff Mon Sep 17 00:00:00 2001 From: zhangjiawei Date: Tue, 22 Oct 2024 16:01:40 +0800 Subject: [PATCH 05/35] set encoding as utf-8 when reading ./book.txt in examples --- examples/lightrag_hf_demo.py | 2 +- examples/lightrag_ollama_demo.py | 2 +- examples/lightrag_openai_compatible_demo.py | 2 +- examples/lightrag_openai_demo.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py index 87312307..91033e50 100644 --- a/examples/lightrag_hf_demo.py +++ b/examples/lightrag_hf_demo.py @@ -30,7 +30,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index c61b71c0..98f1521c 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -21,7 +21,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index fbad1190..aae56821 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -55,7 +55,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index a6e7f3b2..29bc75ca 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -15,7 +15,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search From 7fa7bd546396f6414be4fafc937eb6a307b04404 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 23 Oct 2024 11:24:52 +0800 Subject: [PATCH 06/35] Update lightrag_siliconcloud_demo.py --- examples/lightrag_siliconcloud_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py index e3f5e67e..8be6ae7a 100644 --- a/examples/lightrag_siliconcloud_demo.py +++ b/examples/lightrag_siliconcloud_demo.py @@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray: texts, model="netease-youdao/bce-embedding-base_v1", api_key=os.getenv("UPSTAGE_API_KEY"), - max_token_size=int(512 * 1.5) + max_token_size=512 ) From e20d2a040863d58f670b6ef5eff1c67f007fd4d6 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:50:29 +0800 Subject: [PATCH 07/35] Update base.py --- lightrag/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightrag/base.py b/lightrag/base.py index 50be4f62..cecd5edd 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -18,9 +18,13 @@ class QueryParam: mode: Literal["local", "global", "hybrid", "naive"] = "global" only_need_context: bool = False response_type: str = "Multiple Paragraphs" + # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. top_k: int = 60 + # Number of tokens for the original chunks. 
max_token_for_text_unit: int = 4000 + # Number of tokens for the relationship descriptions max_token_for_global_context: int = 4000 + # Number of tokens for the entity descriptions max_token_for_local_context: int = 4000 From 0bfcc00bdf2fb569cb9e191ae4bb5212b735c96c Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:53:43 +0800 Subject: [PATCH 08/35] Update README.md --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 87335f1f..42a7d5db 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,21 @@ ollama create -f Modelfile qwen2m ``` +### Query Param +```python +class QueryParam: + mode: Literal["local", "global", "hybrid", "naive"] = "global" + only_need_context: bool = False + response_type: str = "Multiple Paragraphs" + # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. + top_k: int = 60 + # Number of tokens for the original chunks. + max_token_for_text_unit: int = 4000 + # Number of tokens for the relationship descriptions + max_token_for_global_context: int = 4000 + # Number of tokens for the entity descriptions + max_token_for_local_context: int = 4000 +``` ### Batch Insert ```python From 2fb3fd25b018fc2a4c1b7a075a298453186a792b Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:54:22 +0800 Subject: [PATCH 09/35] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 42a7d5db..41cb4362 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ ollama create -f Modelfile qwen2m ### Query Param + ```python class QueryParam: mode: Literal["local", "global", "hybrid", "naive"] = "global" @@ -220,6 +221,7 @@ class QueryParam: ``` ### Batch Insert + ```python # Batch Insert: Insert multiple texts at once rag.insert(["TEXT1", "TEXT2",...]) From 5972958e79d42e9770c4a4b2de64d577bad3bcac Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:15:23 +0800 Subject: [PATCH 10/35] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 41cb4362..dbabcb56 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ ollama create -f Modelfile qwen2m ``` + ### Query Param ```python From 63c0283514954fc6f4c1f429cfcd4015136c750c Mon Sep 17 00:00:00 2001 From: tackhwa Date: Wed, 23 Oct 2024 15:02:28 +0800 Subject: [PATCH 11/35] fix hf bug --- examples/lightrag_siliconcloud_demo.py | 4 ++-- lightrag/llm.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py index 8be6ae7a..82cab228 100644 --- a/examples/lightrag_siliconcloud_demo.py +++ b/examples/lightrag_siliconcloud_demo.py @@ -19,7 +19,7 @@ async def llm_model_func( prompt, system_prompt=system_prompt, history_messages=history_messages, - api_key=os.getenv("UPSTAGE_API_KEY"), + api_key=os.getenv("SILICONFLOW_API_KEY"), base_url="https://api.siliconflow.cn/v1/", **kwargs, ) @@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray: return await siliconcloud_embedding( texts, model="netease-youdao/bce-embedding-base_v1", - api_key=os.getenv("UPSTAGE_API_KEY"), + api_key=os.getenv("SILICONFLOW_API_KEY"), max_token_size=512 ) diff --git a/lightrag/llm.py b/lightrag/llm.py index 67f547ea..76adec26 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ 
-1,5 +1,6 @@ import os import copy +from functools import lru_cache import json import aioboto3 import aiohttp @@ -202,15 +203,22 @@ async def bedrock_complete_if_cache( return response["output"]["message"]["content"][0]["text"] +@lru_cache(maxsize=1) +def initialize_hf_model(model_name): + hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True) + hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True) + + return hf_model, hf_tokenizer + + async def hf_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: model_name = model - hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto") + hf_model, hf_tokenizer = initialize_hf_model(model_name) if hf_tokenizer.pad_token is None: # print("use eos token") hf_tokenizer.pad_token = hf_tokenizer.eos_token - hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: From fd30ae4e4587946bea4edd8a6289cef4bf5a58e3 Mon Sep 17 00:00:00 2001 From: tackhwa Date: Wed, 23 Oct 2024 15:25:46 +0800 Subject: [PATCH 12/35] move_code --- lightrag/llm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 76adec26..4dcf535c 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -207,6 +207,8 @@ async def bedrock_complete_if_cache( def initialize_hf_model(model_name): hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True) hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True) + if hf_tokenizer.pad_token is None: + hf_tokenizer.pad_token = hf_tokenizer.eos_token return hf_model, hf_tokenizer @@ -216,9 +218,6 @@ async def hf_model_if_cache( ) -> str: model_name = model hf_model, hf_tokenizer = initialize_hf_model(model_name) - if hf_tokenizer.pad_token is None: - # print("use eos token") - hf_tokenizer.pad_token = hf_tokenizer.eos_token hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: From f0856b918bc100f496272fb6ae4951d6f8620da4 Mon Sep 17 00:00:00 2001 From: Zhenyu Pan <120090196@link.cuhk.edu.cn> Date: Thu, 24 Oct 2024 00:58:52 +0800 Subject: [PATCH 13/35] [hotfix-#75][embedding] Fix the potential embedding problem --- examples/lightrag_openai_compatible_demo.py | 70 +++++++++++++-------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index aae56821..25d3722c 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray: ) +async def get_embedding_dim(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + return embedding_dim + + # function test async def test_funcs(): result = await llm_model_func("How are you?") @@ -43,37 +50,46 @@ async def test_funcs(): print("embedding_func: ", result) -asyncio.run(test_funcs()) +# asyncio.run(test_funcs()) + +async def main(): + try: + embedding_dimension = await get_embedding_dim() + print(f"Detected embedding dimension: {embedding_dimension}") + + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + 
embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func + ), + ) -rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=4096, max_token_size=8192, func=embedding_func - ), -) + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + # Perform naive search + print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) + ) -with open("./book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) + # Perform local search + print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) + ) -# Perform naive search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) -) + # Perform global search + print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) + ) -# Perform local search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) -) + # Perform hybrid search + print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) + ) + except Exception as e: + print(f"An error occurred: {e}") -# Perform global search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) -) - -# Perform hybrid search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) -) +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From 516b4dfb22afec7d686e64d04534790affa22b1c Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Fri, 25 Oct 2024 14:14:36 +0800 Subject: [PATCH 14/35] Update lightrag.py --- lightrag/lightrag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5137af42..b84e22ef 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -208,7 +208,7 @@ class LightRAG: logger.info("[Entity Extraction]...") maybe_new_kg = await extract_entities( inserting_chunks, - knwoledge_graph_inst=self.chunk_entity_relation_graph, + knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, global_config=asdict(self), From ef41871b88c177584a08aba2bb9ab0dcfb612e5b Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Fri, 25 Oct 2024 14:15:31 +0800 Subject: [PATCH 15/35] Update operate.py --- lightrag/operate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index a0729cd8..b90a1ca1 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction( async def _merge_nodes_then_upsert( entity_name: str, nodes_data: list[dict], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): already_entitiy_types = [] already_source_ids = [] already_description = [] - already_node = await knwoledge_graph_inst.get_node(entity_name) + already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: already_entitiy_types.append(already_node["entity_type"]) already_source_ids.extend( From 8fbbf70a8311423ad585f54389ae895d78aa0a6f Mon Sep 17 00:00:00 2001 From: Sanketh Kumar Date: Fri, 25 Oct 2024 13:23:08 +0530 Subject: [PATCH 16/35] Added linting actions for pull request --- .github/workflows/linting.yaml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create 
mode 100644 .github/workflows/linting.yaml diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml new file mode 100644 index 00000000..32886cb0 --- /dev/null +++ b/.github/workflows/linting.yaml @@ -0,0 +1,30 @@ +name: Linting and Formatting + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + lint-and-format: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit + run: pre-commit run --all-files \ No newline at end of file From 5e3ab98d8321f436c313cd1f8d0b1d410e8b91aa Mon Sep 17 00:00:00 2001 From: Sanketh Kumar Date: Fri, 25 Oct 2024 13:32:25 +0530 Subject: [PATCH 17/35] Manually reformatted files --- .github/workflows/linting.yaml | 4 +- .gitignore | 2 +- README.md | 12 +-- examples/graph_visual_with_html.py | 6 +- examples/graph_visual_with_neo4j.py | 30 +++--- examples/lightrag_openai_compatible_demo.py | 27 ++++-- examples/lightrag_siliconcloud_demo.py | 2 +- examples/vram_management_demo.py | 36 +++++-- lightrag/llm.py | 101 ++++++++++++-------- lightrag/utils.py | 46 +++++---- requirements.txt | 4 +- 11 files changed, 175 insertions(+), 95 deletions(-) diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 32886cb0..7c12e0a2 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v2 - + - name: Set up Python uses: actions/setup-python@v2 with: @@ -27,4 +27,4 @@ jobs: pip install pre-commit - name: Run pre-commit - run: pre-commit run --all-files \ No newline at end of file + run: pre-commit run --all-files diff --git a/.gitignore b/.gitignore index 5a41ae32..fd4bd830 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ dickens/ book.txt lightrag-dev/ .idea/ -dist/ \ No newline at end of file +dist/ diff --git a/README.md b/README.md index dbabcb56..abd7ceb9 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() -# import nest_asyncio -# nest_asyncio.apply() +# import nest_asyncio +# nest_asyncio.apply() ######### WORKING_DIR = "./dickens" @@ -157,7 +157,7 @@ rag = LightRAG(
Using Ollama Models - + * If you want to use Ollama models, you only need to set LightRAG as follows: ```python @@ -328,8 +328,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -382,7 +382,7 @@ def main(): except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index b455e6de..e4337a54 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -3,7 +3,7 @@ from pyvis.network import Network import random # Load the GraphML file -G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml') +G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml") # Create a Pyvis network net = Network(notebook=True) @@ -13,7 +13,7 @@ net.from_nx(G) # Add colors to nodes for node in net.nodes: - node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) + node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) # Save and display the network -net.show('knowledge_graph.html') \ No newline at end of file +net.show("knowledge_graph.html") diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 22dde368..7377f21c 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "your_password" + def convert_xml_to_json(xml_path, output_path): """Converts XML file to JSON and saves the output.""" if not os.path.exists(xml_path): @@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path): json_data = xml_to_json(xml_path) if json_data: - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: json.dump(json_data, f, ensure_ascii=False, indent=2) print(f"JSON file created: {output_path}") return json_data @@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path): print("Failed to create JSON data") return None + def process_in_batches(tx, query, data, batch_size): """Process data in batches and execute the given query.""" for i in range(0, len(data), batch_size): - batch = data[i:i + batch_size] + batch = data[i : i + batch_size] tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch}) + def main(): # Paths - xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml') - json_file = os.path.join(WORKING_DIR, 'graph_data.json') + xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") + json_file = os.path.join(WORKING_DIR, "graph_data.json") # Convert XML to JSON json_data = convert_xml_to_json(xml_file, json_file) @@ -46,8 +49,8 @@ def main(): return # Load nodes and edges - nodes = json_data.get('nodes', []) - edges = json_data.get('edges', []) + nodes = json_data.get("nodes", []) + edges = json_data.get("edges", []) # Neo4j queries create_nodes_query = """ @@ -56,8 +59,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -100,19 +103,24 @@ def main(): # 
Execute queries in batches with driver.session() as session: # Insert nodes in batches - session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES) + session.execute_write( + process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES + ) # Insert edges in batches - session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES) + session.execute_write( + process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES + ) # Set displayName and labels session.run(set_displayname_and_labels_query) except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() + if __name__ == "__main__": main() diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index 25d3722c..2470fc00 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -52,6 +52,7 @@ async def test_funcs(): # asyncio.run(test_funcs()) + async def main(): try: embedding_dimension = await get_embedding_dim() @@ -61,35 +62,47 @@ async def main(): working_dir=WORKING_DIR, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, ), ) - with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) ) # Perform local search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) ) # Perform global search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="global"), + ) ) # Perform hybrid search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid"), + ) ) except Exception as e: print(f"An error occurred: {e}") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py index 82cab228..a73f16c5 100644 --- a/examples/lightrag_siliconcloud_demo.py +++ b/examples/lightrag_siliconcloud_demo.py @@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray: texts, model="netease-youdao/bce-embedding-base_v1", api_key=os.getenv("SILICONFLOW_API_KEY"), - max_token_size=512 + max_token_size=512, ) diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py index ec750254..c173b913 100644 --- a/examples/vram_management_demo.py +++ b/examples/vram_management_demo.py @@ -27,11 +27,12 @@ rag = LightRAG( # Read all .txt files from the TEXT_FILES_DIR directory texts = [] for filename in os.listdir(TEXT_FILES_DIR): - if filename.endswith('.txt'): + if filename.endswith(".txt"): file_path = os.path.join(TEXT_FILES_DIR, filename) - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: texts.append(file.read()) + # Batch insert texts into LightRAG with a retry mechanism def insert_texts_with_retry(rag, texts, 
retries=3, delay=5): for _ in range(retries): @@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5): rag.insert(texts) return except Exception as e: - print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...") + print( + f"Error occurred during insertion: {e}. Retrying in {delay} seconds..." + ) time.sleep(delay) raise RuntimeError("Failed to insert texts after multiple retries.") + insert_texts_with_retry(rag, texts) # Perform different types of queries and handle potential errors try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) except Exception as e: print(f"Error performing naive search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) except Exception as e: print(f"Error performing local search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) except Exception as e: print(f"Error performing global search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) except Exception as e: print(f"Error performing hybrid search: {e}") + # Function to clear VRAM resources def clear_vram(): os.system("sudo nvidia-smi --gpu-reset") + # Regularly clear VRAM to prevent overflow clear_vram_interval = 3600 # Clear once every hour start_time = time.time() diff --git a/lightrag/llm.py b/lightrag/llm.py index 4dcf535c..eaaa2b75 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -7,7 +7,13 @@ import aiohttp import numpy as np import ollama -from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI +from openai import ( + AsyncOpenAI, + APIConnectionError, + RateLimitError, + Timeout, + AsyncAzureOpenAI, +) import base64 import struct @@ -70,26 +76,31 @@ async def openai_complete_if_cache( ) return response.choices[0].message.content + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) -async def azure_openai_complete_if_cache(model, +async def azure_openai_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, - **kwargs): + **kwargs, +): if api_key: os.environ["AZURE_OPENAI_API_KEY"] = api_key if base_url: os.environ["AZURE_OPENAI_ENDPOINT"] = base_url - openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + openai_async_client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] @@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model, ) return response.choices[0].message.content + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" @@ -205,8 +217,12 @@ async def 
bedrock_complete_if_cache( @lru_cache(maxsize=1) def initialize_hf_model(model_name): - hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True) - hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True) + hf_tokenizer = AutoTokenizer.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) if hf_tokenizer.pad_token is None: hf_tokenizer.pad_token = hf_tokenizer.eos_token @@ -328,8 +344,9 @@ async def gpt_4o_mini_complete( **kwargs, ) + async def azure_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: return await azure_openai_complete_if_cache( "conversation-4o-mini", @@ -339,6 +356,7 @@ async def azure_openai_complete( **kwargs, ) + async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -418,9 +436,11 @@ async def azure_openai_embedding( if base_url: os.environ["AZURE_OPENAI_ENDPOINT"] = base_url - openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + openai_async_client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" @@ -440,35 +460,28 @@ async def siliconcloud_embedding( max_token_size: int = 512, api_key: str = None, ) -> np.ndarray: - if api_key and not api_key.startswith('Bearer '): - api_key = 'Bearer ' + api_key + if api_key and not api_key.startswith("Bearer "): + api_key = "Bearer " + api_key - headers = { - "Authorization": api_key, - "Content-Type": "application/json" - } + headers = {"Authorization": api_key, "Content-Type": "application/json"} truncate_texts = [text[0:max_token_size] for text in texts] - payload = { - "model": model, - "input": truncate_texts, - "encoding_format": "base64" - } + payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} base64_strings = [] async with aiohttp.ClientSession() as session: async with session.post(base_url, headers=headers, json=payload) as response: content = await response.json() - if 'code' in content: + if "code" in content: raise ValueError(content) - base64_strings = [item['embedding'] for item in content['data']] - + base64_strings = [item["embedding"] for item in content["data"]] + embeddings = [] for string in base64_strings: decode_bytes = base64.b64decode(string) n = len(decode_bytes) // 4 - float_array = struct.unpack('<' + 'f' * n, decode_bytes) + float_array = struct.unpack("<" + "f" * n, decode_bytes) embeddings.append(float_array) return np.array(embeddings) @@ -563,6 +576,7 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: return embed_text + class Model(BaseModel): """ This is a Pydantic model class named 'Model' that is used to define a custom language model. @@ -580,14 +594,20 @@ class Model(BaseModel): The 'kwargs' dictionary contains the model name and API key to be passed to the function. """ - gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. 
The response must be a string") - kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc") + gen_func: Callable[[Any], str] = Field( + ..., + description="A function that generates the response from the llm. The response must be a string", + ) + kwargs: Dict[str, Any] = Field( + ..., + description="The arguments to pass to the callable function. Eg. the api key, model name, etc", + ) class Config: arbitrary_types_allowed = True -class MultiModel(): +class MultiModel: """ Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. Could also be used for spliting across diffrent models or providers. @@ -611,26 +631,31 @@ class MultiModel(): ) ``` """ + def __init__(self, models: List[Model]): self._models = models self._current_model = 0 - + def _next_model(self): self._current_model = (self._current_model + 1) % len(self._models) return self._models[self._current_model] async def llm_model_func( - self, - prompt, system_prompt=None, history_messages=[], **kwargs + self, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - kwargs.pop("model", None) # stop from overwriting the custom model name + kwargs.pop("model", None) # stop from overwriting the custom model name next_model = self._next_model() - args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs) - - return await next_model.gen_func( - **args + args = dict( + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + **next_model.kwargs, ) + return await next_model.gen_func(**args) + + if __name__ == "__main__": import asyncio diff --git a/lightrag/utils.py b/lightrag/utils.py index 9a68c16b..0da4a51a 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -185,6 +185,7 @@ def save_data_to_file(data, file_name): with open(file_name, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=4) + def xml_to_json(xml_file): try: tree = ET.parse(xml_file) @@ -194,31 +195,42 @@ def xml_to_json(xml_file): print(f"Root element: {root.tag}") print(f"Root attributes: {root.attrib}") - data = { - "nodes": [], - "edges": [] - } + data = {"nodes": [], "edges": []} # Use namespace - namespace = {'': 'http://graphml.graphdrawing.org/xmlns'} + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} - for node in root.findall('.//node', namespace): + for node in root.findall(".//node", namespace): node_data = { - "id": node.get('id').strip('"'), - "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "", - "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "", - "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else "" + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') + if node.find("./data[@key='d0']", namespace) is not None + else "", + "description": node.find("./data[@key='d1']", namespace).text + if node.find("./data[@key='d1']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", } data["nodes"].append(node_data) - for edge in 
root.findall('.//edge', namespace): + for edge in root.findall(".//edge", namespace): edge_data = { - "source": edge.get('source').strip('"'), - "target": edge.get('target').strip('"'), - "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0, - "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "", - "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "", - "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else "" + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d3']", namespace).text) + if edge.find("./data[@key='d3']", namespace) is not None + else 0.0, + "description": edge.find("./data[@key='d4']", namespace).text + if edge.find("./data[@key='d4']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d5']", namespace).text + if edge.find("./data[@key='d5']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", } data["edges"].append(edge_data) diff --git a/requirements.txt b/requirements.txt index 5b3396fb..98f32b0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ accelerate aioboto3 +aiohttp graspologic hnswlib nano-vectordb networkx ollama openai +pyvis tenacity tiktoken torch transformers xxhash -pyvis -aiohttp \ No newline at end of file From a16831616ee7b745ffdf7db3ee846c942a516f31 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Fri, 25 Oct 2024 19:25:26 +0800 Subject: [PATCH 18/35] fix Step_3_openai_compatible.py --- reproduce/Step_3_openai_compatible.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index 2be5ea5c..5e2ef778 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -50,8 +50,8 @@ def extract_queries(file_path): async def process_query(query_text, rag_instance, query_param): try: - result, context = await rag_instance.aquery(query_text, param=query_param) - return {"query": query_text, "result": result, "context": context}, None + result = await rag_instance.aquery(query_text, param=query_param) + return {"query": query_text, "result": result}, None except Exception as e: return None, {"query": query_text, "error": str(e)} From 72ce8b85f4e6e8144bb3ee2d690df9368bba351c Mon Sep 17 00:00:00 2001 From: jatin009v Date: Fri, 25 Oct 2024 18:39:55 +0530 Subject: [PATCH 19/35] Key Enhancements: Error Handling: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Handled potential FileNotFoundError for README.md and requirements.txt. Checked for missing required metadata and raised an informative error if any are missing. Automated Package Discovery: Replaced packages=["lightrag"] with setuptools.find_packages() to automatically find sub-packages and exclude test or documentation directories. Additional Metadata: Added Development Status in classifiers to indicate a "Beta" release (modify based on the project's maturity). Used project_urls to link documentation, source code, and an issue tracker, which are standard for open-source projects. 
Compatibility: Included include_package_data=True to include additional files specified in MANIFEST.in. These changes enhance the readability, reliability, and openness of the code, making it more contributor-friendly and ensuring it’s ready for open-source distribution. --- setup.py | 74 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/setup.py b/setup.py index 47222420..bdf49f02 100644 --- a/setup.py +++ b/setup.py @@ -1,39 +1,71 @@ import setuptools +from pathlib import Path -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() +# Reading the long description from README.md +def read_long_description(): + try: + return Path("README.md").read_text(encoding="utf-8") + except FileNotFoundError: + return "A description of LightRAG is currently unavailable." +# Retrieving metadata from __init__.py +def retrieve_metadata(): + vars2find = ["__author__", "__version__", "__url__"] + vars2readme = {} + try: + with open("./lightrag/__init__.py") as f: + for line in f.readlines(): + for v in vars2find: + if line.startswith(v): + line = line.replace(" ", "").replace('"', "").replace("'", "").strip() + vars2readme[v] = line.split("=")[1] + except FileNotFoundError: + raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.") + + # Checking if all required variables are found + missing_vars = [v for v in vars2find if v not in vars2readme] + if missing_vars: + raise ValueError(f"Missing required metadata variables in __init__.py: {missing_vars}") + + return vars2readme -vars2find = ["__author__", "__version__", "__url__"] -vars2readme = {} -with open("./lightrag/__init__.py") as f: - for line in f.readlines(): - for v in vars2find: - if line.startswith(v): - line = line.replace(" ", "").replace('"', "").replace("'", "").strip() - vars2readme[v] = line.split("=")[1] +# Reading dependencies from requirements.txt +def read_requirements(): + deps = [] + try: + with open("./requirements.txt") as f: + deps = [line.strip() for line in f if line.strip()] + except FileNotFoundError: + print("Warning: 'requirements.txt' not found. 
No dependencies will be installed.") + return deps -deps = [] -with open("./requirements.txt") as f: - for line in f.readlines(): - if not line.strip(): - continue - deps.append(line.strip()) +metadata = retrieve_metadata() +long_description = read_long_description() +requirements = read_requirements() setuptools.setup( name="lightrag-hku", - url=vars2readme["__url__"], - version=vars2readme["__version__"], - author=vars2readme["__author__"], + url=metadata["__url__"], + version=metadata["__version__"], + author=metadata["__author__"], description="LightRAG: Simple and Fast Retrieval-Augmented Generation", long_description=long_description, long_description_content_type="text/markdown", - packages=["lightrag"], + packages=setuptools.find_packages(exclude=("tests*", "docs*")), # Automatically find packages classifiers=[ + "Development Status :: 4 - Beta", "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", ], python_requires=">=3.9", - install_requires=deps, + install_requires=requirements, + include_package_data=True, # Includes non-code files from MANIFEST.in + project_urls={ # Additional project metadata + "Documentation": metadata.get("__url__", ""), + "Source": metadata.get("__url__", ""), + "Tracker": f"{metadata.get('__url__', '')}/issues" if metadata.get("__url__") else "" + }, ) From 542f8835f807f2f99ddad1be83f30523a5e82996 Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 00:37:03 +0800 Subject: [PATCH 20/35] add Algorithm Flowchart --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index dbabcb56..f2f5c20e 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,10 @@ This repository hosts the code of LightRAG. The structure of this code is based - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! +## Algorithm Flowchart + + + ## Install * Install from source (Recommend) From e5cb01b16b92b2473f5dc2e7ad327b60466fbe3c Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 00:37:46 +0800 Subject: [PATCH 21/35] add Algorithm FLowchart --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f2f5c20e..0f8659b1 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ This repository hosts the code of LightRAG. 
The structure of this code is based ## Algorithm Flowchart +![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3) ## Install From d9054c6e4f71147dafe071702512b5498224009b Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 02:20:23 +0800 Subject: [PATCH 22/35] fix hf output bug --- lightrag/llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 4dcf535c..692937fb 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -266,10 +266,11 @@ async def hf_model_if_cache( input_ids = hf_tokenizer( input_prompt, return_tensors="pt", padding=True, truncation=True ).to("cuda") + inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} output = hf_model.generate( **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True ) - response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True) + response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text From 5bfd107f5ebd38dbe7e31d4f3d6bf9d3c25389fa Mon Sep 17 00:00:00 2001 From: tackhwa <55059307+tackhwa@users.noreply.github.com> Date: Sat, 26 Oct 2024 02:42:40 +0800 Subject: [PATCH 23/35] Update token length --- lightrag/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 692937fb..ab459fc7 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -268,7 +268,7 @@ async def hf_model_if_cache( ).to("cuda") inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} output = hf_model.generate( - **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True + **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True ) response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) if hashing_kv is not None: From e3d978d331cca8c48df09618fde5d74a0711c285 Mon Sep 17 00:00:00 2001 From: Yazington Date: Sat, 26 Oct 2024 00:11:21 -0400 Subject: [PATCH 24/35] fixing bug --- lightrag/lightrag.py | 6 ++++-- lightrag/operate.py | 26 +++++++++++++------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5137af42..3004f5ed 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -85,7 +85,9 @@ class LightRAG: # LLM llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# - llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' + llm_model_name: str = ( + "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' + ) llm_model_max_token_size: int = 32768 llm_model_max_async: int = 16 @@ -208,7 +210,7 @@ class LightRAG: logger.info("[Entity Extraction]...") maybe_new_kg = await extract_entities( inserting_chunks, - knwoledge_graph_inst=self.chunk_entity_relation_graph, + knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, global_config=asdict(self), diff --git a/lightrag/operate.py b/lightrag/operate.py index a0729cd8..8a6820f5 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction( async def _merge_nodes_then_upsert( entity_name: str, nodes_data: list[dict], - knwoledge_graph_inst: BaseGraphStorage, + 
knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): already_entitiy_types = [] already_source_ids = [] already_description = [] - already_node = await knwoledge_graph_inst.get_node(entity_name) + already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: already_entitiy_types.append(already_node["entity_type"]) already_source_ids.extend( @@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert( description=description, source_id=source_id, ) - await knwoledge_graph_inst.upsert_node( + await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) @@ -172,7 +172,7 @@ async def _merge_edges_then_upsert( src_id: str, tgt_id: str, edges_data: list[dict], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): already_weights = [] @@ -180,8 +180,8 @@ async def _merge_edges_then_upsert( already_description = [] already_keywords = [] - if await knwoledge_graph_inst.has_edge(src_id, tgt_id): - already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id) + if await knowledge_graph_inst.has_edge(src_id, tgt_id): + already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) already_weights.append(already_edge["weight"]) already_source_ids.extend( split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) @@ -202,8 +202,8 @@ async def _merge_edges_then_upsert( set([dp["source_id"] for dp in edges_data] + already_source_ids) ) for need_insert_id in [src_id, tgt_id]: - if not (await knwoledge_graph_inst.has_node(need_insert_id)): - await knwoledge_graph_inst.upsert_node( + if not (await knowledge_graph_inst.has_node(need_insert_id)): + await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ "source_id": source_id, @@ -214,7 +214,7 @@ async def _merge_edges_then_upsert( description = await _handle_entity_relation_summary( (src_id, tgt_id), description, global_config ) - await knwoledge_graph_inst.upsert_edge( + await knowledge_graph_inst.upsert_edge( src_id, tgt_id, edge_data=dict( @@ -237,7 +237,7 @@ async def _merge_edges_then_upsert( async def extract_entities( chunks: dict[str, TextChunkSchema], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict, @@ -341,13 +341,13 @@ async def extract_entities( maybe_edges[tuple(sorted(k))].extend(v) all_entities_data = await asyncio.gather( *[ - _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config) + _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) for k, v in maybe_nodes.items() ] ) all_relationships_data = await asyncio.gather( *[ - _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config) + _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config) for k, v in maybe_edges.items() ] ) @@ -384,7 +384,7 @@ async def extract_entities( } await relationships_vdb.upsert(data_for_vdb) - return knwoledge_graph_inst + return knowledge_graph_inst async def local_query( From f6e97c052813d216913cb00d84b668cf732bf6e3 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Sat, 26 Oct 2024 14:04:11 +0800 Subject: [PATCH 25/35] Update graph_visual_with_html.py --- examples/graph_visual_with_html.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index e4337a54..11279b3a 100644 --- a/examples/graph_visual_with_html.py +++ 
b/examples/graph_visual_with_html.py @@ -6,7 +6,7 @@ import random G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml") # Create a Pyvis network -net = Network(notebook=True) +net = Network(height="100vh", notebook=True) # Convert NetworkX graph to Pyvis network net.from_nx(G) From 4d078e948f9f85eb50cedf178cca77b04a8df74c Mon Sep 17 00:00:00 2001 From: LarFii <834462287@qq.com> Date: Sat, 26 Oct 2024 14:40:17 +0800 Subject: [PATCH 26/35] update version --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index db81e005..8e76a260 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "0.0.7" +__version__ = "0.0.8" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 02f94ab228c17122833173aa4c9825bada4a176f Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 15:56:48 +0800 Subject: [PATCH 27/35] [feat] Add API server implementation and endpoints --- README.md | 119 ++++++++++++++ .../lightrag_api_openai_compatible_demo.py | 153 ++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 examples/lightrag_api_openai_compatible_demo.py diff --git a/README.md b/README.md index 7fab9a01..d11b1691 100644 --- a/README.md +++ b/README.md @@ -397,6 +397,125 @@ if __name__ == "__main__":
+
+## API Server Implementation
+
+LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
+
+### Setting up the API Server
+<details>
+<summary>Click to expand setup instructions</summary>
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory"  # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+</details>
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+<details>
+<summary>Click to view Query endpoint details</summary>
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "query": "Your question here",
+    "mode": "hybrid"  // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+     -H "Content-Type: application/json" \
+     -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+</details>
+
+#### 2. Insert Text Endpoint
+<details>
+<summary>Click to view Insert Text endpoint details</summary>
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+     -H "Content-Type: application/json" \
+     -d '{"text": "Content to be inserted into RAG"}'
+```
+</details>
+
+#### 3. Insert File Endpoint
+<details>
+<summary>Click to view Insert File endpoint details</summary>
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+     -H "Content-Type: application/json" \
+     -d '{"file_path": "./book.txt"}'
+```
+</details>
+
+#### 4. Health Check Endpoint
+<details>
+<summary>Click to view Health Check endpoint details</summary>
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+</details>
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+<details>
+<summary>Click to view error handling details</summary>
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Supports multiple file encodings (UTF-8 and GBK)
+</details>
+ ## Evaluation ### Dataset The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain). diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py new file mode 100644 index 00000000..f8d105ea --- /dev/null +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -0,0 +1,153 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import os +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np +from typing import Optional +import asyncio +import nest_asyncio + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +DEFAULT_RAG_DIR="index_default" +app = FastAPI(title="LightRAG API", description="API for RAG operations") + +# Configure working directory +WORKING_DIR = os.environ.get('RAG_DIR', f'{DEFAULT_RAG_DIR}') +print(f"WORKING_DIR: {WORKING_DIR}") +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# LLM model function +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key='YOUR_API_KEY', + base_url="YourURL/v1", + **kwargs, + ) + +# Embedding function +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model="text-embedding-3-large", + api_key='YOUR_API_KEY', + base_url="YourURL/v1", + ) + +# Initialize RAG instance +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=3072, max_token_size=8192, func=embedding_func + ), +) + +# Data models +class QueryRequest(BaseModel): + query: str + mode: str = "hybrid" + +class InsertRequest(BaseModel): + text: str + +class InsertFileRequest(BaseModel): + file_path: str + +class Response(BaseModel): + status: str + data: Optional[str] = None + message: Optional[str] = None + +# API routes +@app.post("/query", response_model=Response) +async def query_endpoint(request: QueryRequest): + try: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + lambda: rag.query(request.query, param=QueryParam(mode=request.mode)) + ) + return Response( + status="success", + data=result + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/insert", response_model=Response) +async def insert_endpoint(request: InsertRequest): + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: rag.insert(request.text)) + return Response( + status="success", + message="Text inserted successfully" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/insert_file", response_model=Response) +async def insert_file(request: InsertFileRequest): + try: + # Check if file exists + if not os.path.exists(request.file_path): + raise HTTPException( + status_code=404, + detail=f"File not found: {request.file_path}" + ) + + # Read file content + try: + with open(request.file_path, 'r', encoding='utf-8') as f: + content = f.read() + except UnicodeDecodeError: + # If UTF-8 decoding fails, try other encodings + with open(request.file_path, 'r', encoding='gbk') as f: + content = f.read() + + # Insert file content + loop = asyncio.get_event_loop() + await 
loop.run_in_executor(None, lambda: rag.insert(content)) + + return Response( + status="success", + message=f"File content from {request.file_path} inserted successfully" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8020) + +# Usage example +# To run the server, use the following command in your terminal: +# python lightrag_api_openai_compatible_demo.py + +# Example requests: +# 1. Query: +# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' + +# 2. Insert text: +# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' + +# 3. Insert file: +# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' + +# 4. Health check: +# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file From 08feac942ad0de01ccbe16253d7b7a2ad35b7621 Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 16:00:30 +0800 Subject: [PATCH 28/35] Refactor code formatting in lightrag_api_openai_compatible_demo.py --- .../lightrag_api_openai_compatible_demo.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index f8d105ea..ad9560dc 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -12,7 +12,7 @@ import nest_asyncio # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() -DEFAULT_RAG_DIR="index_default" +DEFAULT_RAG_DIR = "index_default" app = FastAPI(title="LightRAG API", description="API for RAG operations") # Configure working directory @@ -22,6 +22,8 @@ if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) # LLM model function + + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -36,6 +38,8 @@ async def llm_model_func( ) # Embedding function + + async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, @@ -54,29 +58,37 @@ rag = LightRAG( ) # Data models + + class QueryRequest(BaseModel): query: str mode: str = "hybrid" + class InsertRequest(BaseModel): text: str + class InsertFileRequest(BaseModel): file_path: str + class Response(BaseModel): status: str data: Optional[str] = None message: Optional[str] = None # API routes + + @app.post("/query", response_model=Response) async def query_endpoint(request: QueryRequest): try: loop = asyncio.get_event_loop() result = await loop.run_in_executor( - None, - lambda: rag.query(request.query, param=QueryParam(mode=request.mode)) + None, + lambda: rag.query( + request.query, param=QueryParam(mode=request.mode)) ) return Response( status="success", @@ -85,6 +97,7 @@ async def query_endpoint(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.post("/insert", response_model=Response) async def insert_endpoint(request: InsertRequest): try: @@ -97,6 +110,7 @@ async def insert_endpoint(request: InsertRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.post("/insert_file", response_model=Response) async def insert_file(request: InsertFileRequest): try: @@ 
-106,7 +120,7 @@ async def insert_file(request: InsertFileRequest): status_code=404, detail=f"File not found: {request.file_path}" ) - + # Read file content try: with open(request.file_path, 'r', encoding='utf-8') as f: @@ -115,11 +129,11 @@ async def insert_file(request: InsertFileRequest): # If UTF-8 decoding fails, try other encodings with open(request.file_path, 'r', encoding='gbk') as f: content = f.read() - + # Insert file content loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: rag.insert(content)) - + return Response( status="success", message=f"File content from {request.file_path} inserted successfully" @@ -127,6 +141,7 @@ async def insert_file(request: InsertFileRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.get("/health") async def health_check(): return {"status": "healthy"} @@ -150,4 +165,4 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' # 4. Health check: -# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file +# curl -X GET "http://127.0.0.1:8020/health" From fb84c1e5be3b6b5dc34ed96606194b93624f3900 Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 16:09:36 +0800 Subject: [PATCH 29/35] Refactor code formatting in lightrag_api_openai_compatible_demo.py --- .../lightrag_api_openai_compatible_demo.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index ad9560dc..2cd262bb 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -16,7 +16,7 @@ DEFAULT_RAG_DIR = "index_default" app = FastAPI(title="LightRAG API", description="API for RAG operations") # Configure working directory -WORKING_DIR = os.environ.get('RAG_DIR', f'{DEFAULT_RAG_DIR}') +WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") print(f"WORKING_DIR: {WORKING_DIR}") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -32,11 +32,12 @@ async def llm_model_func( prompt, system_prompt=system_prompt, history_messages=history_messages, - api_key='YOUR_API_KEY', + api_key="YOUR_API_KEY", base_url="YourURL/v1", **kwargs, ) + # Embedding function @@ -44,10 +45,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, model="text-embedding-3-large", - api_key='YOUR_API_KEY', + api_key="YOUR_API_KEY", base_url="YourURL/v1", ) + # Initialize RAG instance rag = LightRAG( working_dir=WORKING_DIR, @@ -78,6 +80,7 @@ class Response(BaseModel): data: Optional[str] = None message: Optional[str] = None + # API routes @@ -86,14 +89,9 @@ async def query_endpoint(request: QueryRequest): try: loop = asyncio.get_event_loop() result = await loop.run_in_executor( - None, - lambda: rag.query( - request.query, param=QueryParam(mode=request.mode)) - ) - return Response( - status="success", - data=result + None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode)) ) + return Response(status="success", data=result) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -103,10 +101,7 @@ async def insert_endpoint(request: InsertRequest): try: loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: rag.insert(request.text)) - return Response( - status="success", - message="Text inserted successfully" - ) + return 
Response(status="success", message="Text inserted successfully") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -117,17 +112,16 @@ async def insert_file(request: InsertFileRequest): # Check if file exists if not os.path.exists(request.file_path): raise HTTPException( - status_code=404, - detail=f"File not found: {request.file_path}" + status_code=404, detail=f"File not found: {request.file_path}" ) # Read file content try: - with open(request.file_path, 'r', encoding='utf-8') as f: + with open(request.file_path, "r", encoding="utf-8") as f: content = f.read() except UnicodeDecodeError: # If UTF-8 decoding fails, try other encodings - with open(request.file_path, 'r', encoding='gbk') as f: + with open(request.file_path, "r", encoding="gbk") as f: content = f.read() # Insert file content @@ -136,7 +130,7 @@ async def insert_file(request: InsertFileRequest): return Response( status="success", - message=f"File content from {request.file_path} inserted successfully" + message=f"File content from {request.file_path} inserted successfully", ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -146,8 +140,10 @@ async def insert_file(request: InsertFileRequest): async def health_check(): return {"status": "healthy"} + if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8020) # Usage example From 88f4e3452839e3b1f723c9688a888c8aefeb5f21 Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:11:15 +0800 Subject: [PATCH 30/35] support lmdeploy backend --- examples/lightrag_lmdeploy_demo.py | 74 +++++++++++++++++++++ lightrag/llm.py | 100 +++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 175 insertions(+) create mode 100644 examples/lightrag_lmdeploy_demo.py diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py new file mode 100644 index 00000000..ea7ace0e --- /dev/null +++ b/examples/lightrag_lmdeploy_demo.py @@ -0,0 +1,74 @@ +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm import lmdeploy_model_if_cache, hf_embedding +from lightrag.utils import EmbeddingFunc +from transformers import AutoModel, AutoTokenizer + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def lmdeploy_model_complete( + prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + return await lmdeploy_model_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + ## please specify chat_template if your local path does not follow original HF file name, + ## or model_name is a pytorch model on huggingface.co, + ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py + ## for a list of chat_template available in lmdeploy. + chat_template = "llama3", + # model_format ='awq', # if you are using awq quantization model. + # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. 
+ **kwargs, + ) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=lmdeploy_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=5000, + func=lambda texts: hf_embedding( + texts, + tokenizer=AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + embed_model=AutoModel.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + ), + ), +) + + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/lightrag/llm.py b/lightrag/llm.py index bb0d6063..028084bd 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -322,6 +322,106 @@ async def ollama_model_if_cache( return result +@lru_cache(maxsize=1) +def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0): + from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + lmdeploy_pipe = pipeline( + model_path=model, + backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy), + chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None, + log_level='WARNING') + return lmdeploy_pipe + + +async def lmdeploy_model_if_cache( + model, prompt, system_prompt=None, history_messages=[], + chat_template=None, model_format='hf',quant_policy=0, **kwargs +) -> str: + """ + Args: + model (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download + from ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + chat_template (str): needed when model is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, + and when the model name of local path did not match the original model name in HF. + tp (int): tensor parallel + prompt (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. + do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. + Default to be False, which means greedy decoding will be applied. 
+ """ + try: + import lmdeploy + from lmdeploy import version_info, GenerationConfig + except: + raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") + + kwargs.pop("response_format", None) + max_new_tokens = kwargs.pop("max_tokens", 512) + tp = kwargs.pop('tp', 1) + skip_special_tokens = kwargs.pop('skip_special_tokens', False) + do_preprocess = kwargs.pop('do_preprocess', True) + do_sample = kwargs.pop('do_sample', False) + gen_params = kwargs + + version = version_info + if do_sample is not None and version < (0, 6, 0): + raise RuntimeError( + '`do_sample` parameter is not supported by lmdeploy until ' + f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}') + else: + do_sample = True + gen_params.update(do_sample=do_sample) + + lmdeploy_pipe = initialize_lmdeploy_pipeline( + model=model, + tp=tp, + chat_template=chat_template, + model_format=model_format, + quant_policy=quant_policy, + log_level='WARNING') + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + gen_config = GenerationConfig( + skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params) + + response = "" + async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config, + do_preprocess=do_preprocess, stream_response=False, session_id=1): + response += res.response + + if hashing_kv is not None: + await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) + return response + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: diff --git a/requirements.txt b/requirements.txt index 98f32b0a..6b0e025a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ tiktoken torch transformers xxhash +# lmdeploy[all] From f71e389d5b2772b1cc381dada644b9118334d9dc Mon Sep 17 00:00:00 2001 From: "zhenjie.ye" Date: Sat, 26 Oct 2024 16:12:10 +0800 Subject: [PATCH 31/35] Refactor code formatting in lightrag_api_openai_compatible_demo.py --- lightrag/lightrag.py | 4 +--- lightrag/llm.py | 4 +++- setup.py | 31 ++++++++++++++++++++++++------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 3004f5ed..b84e22ef 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -85,9 +85,7 @@ class LightRAG: # LLM llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# - llm_model_name: str = ( - "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' - ) + llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_max_token_size: int = 32768 llm_model_max_async: int = 16 diff --git a/lightrag/llm.py b/lightrag/llm.py index bb0d6063..fd6b72d6 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -286,7 +286,9 @@ async def hf_model_if_cache( output = hf_model.generate( **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True ) - response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) + response_text = hf_tokenizer.decode( + output[0][len(inputs["input_ids"][0]) :], 
skip_special_tokens=True + ) if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text diff --git a/setup.py b/setup.py index bdf49f02..1b1f65f0 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ import setuptools from pathlib import Path + # Reading the long description from README.md def read_long_description(): try: @@ -8,6 +9,7 @@ def read_long_description(): except FileNotFoundError: return "A description of LightRAG is currently unavailable." + # Retrieving metadata from __init__.py def retrieve_metadata(): vars2find = ["__author__", "__version__", "__url__"] @@ -17,18 +19,26 @@ def retrieve_metadata(): for line in f.readlines(): for v in vars2find: if line.startswith(v): - line = line.replace(" ", "").replace('"', "").replace("'", "").strip() + line = ( + line.replace(" ", "") + .replace('"', "") + .replace("'", "") + .strip() + ) vars2readme[v] = line.split("=")[1] except FileNotFoundError: raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.") - + # Checking if all required variables are found missing_vars = [v for v in vars2find if v not in vars2readme] if missing_vars: - raise ValueError(f"Missing required metadata variables in __init__.py: {missing_vars}") - + raise ValueError( + f"Missing required metadata variables in __init__.py: {missing_vars}" + ) + return vars2readme + # Reading dependencies from requirements.txt def read_requirements(): deps = [] @@ -36,9 +46,12 @@ def read_requirements(): with open("./requirements.txt") as f: deps = [line.strip() for line in f if line.strip()] except FileNotFoundError: - print("Warning: 'requirements.txt' not found. No dependencies will be installed.") + print( + "Warning: 'requirements.txt' not found. No dependencies will be installed." 
+ ) return deps + metadata = retrieve_metadata() long_description = read_long_description() requirements = read_requirements() @@ -51,7 +64,9 @@ setuptools.setup( description="LightRAG: Simple and Fast Retrieval-Augmented Generation", long_description=long_description, long_description_content_type="text/markdown", - packages=setuptools.find_packages(exclude=("tests*", "docs*")), # Automatically find packages + packages=setuptools.find_packages( + exclude=("tests*", "docs*") + ), # Automatically find packages classifiers=[ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3", @@ -66,6 +81,8 @@ setuptools.setup( project_urls={ # Additional project metadata "Documentation": metadata.get("__url__", ""), "Source": metadata.get("__url__", ""), - "Tracker": f"{metadata.get('__url__', '')}/issues" if metadata.get("__url__") else "" + "Tracker": f"{metadata.get('__url__', '')}/issues" + if metadata.get("__url__") + else "", }, ) From 2120a6dabb320f8a3a5f9388afda65b80d4093c8 Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:13:18 +0800 Subject: [PATCH 32/35] pre-commit --- examples/lightrag_lmdeploy_demo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py index ea7ace0e..aeb96f71 100644 --- a/examples/lightrag_lmdeploy_demo.py +++ b/examples/lightrag_lmdeploy_demo.py @@ -10,10 +10,11 @@ WORKING_DIR = "./dickens" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def lmdeploy_model_complete( prompt=None, system_prompt=None, history_messages=[], **kwargs ) -> str: - model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await lmdeploy_model_if_cache( model_name, prompt, @@ -23,7 +24,7 @@ async def lmdeploy_model_complete( ## or model_name is a pytorch model on huggingface.co, ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py ## for a list of chat_template available in lmdeploy. - chat_template = "llama3", + chat_template="llama3", # model_format ='awq', # if you are using awq quantization model. # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. 
**kwargs, @@ -33,7 +34,7 @@ async def lmdeploy_model_complete( rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=lmdeploy_model_complete, - llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model embedding_func=EmbeddingFunc( embedding_dim=384, max_token_size=5000, From 81d5b904fbf06379047ba717869af111d2041333 Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:24:35 +0800 Subject: [PATCH 33/35] update do_preprocess --- lightrag/llm.py | 77 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 028084bd..d86886ea 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -286,7 +286,9 @@ async def hf_model_if_cache( output = hf_model.generate( **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True ) - response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) + response_text = hf_tokenizer.decode( + output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True + ) if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text @@ -323,19 +325,38 @@ async def ollama_model_if_cache( @lru_cache(maxsize=1) -def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0): +def initialize_lmdeploy_pipeline( + model, + tp=1, + chat_template=None, + log_level="WARNING", + model_format="hf", + quant_policy=0, +): from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + lmdeploy_pipe = pipeline( model_path=model, - backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy), - chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None, - log_level='WARNING') + backend_config=TurbomindEngineConfig( + tp=tp, model_format=model_format, quant_policy=quant_policy + ), + chat_template_config=ChatTemplateConfig(model_name=chat_template) + if chat_template + else None, + log_level="WARNING", + ) return lmdeploy_pipe async def lmdeploy_model_if_cache( - model, prompt, system_prompt=None, history_messages=[], - chat_template=None, model_format='hf',quant_policy=0, **kwargs + model, + prompt, + system_prompt=None, + history_messages=[], + chat_template=None, + model_format="hf", + quant_policy=0, + **kwargs, ) -> str: """ Args: @@ -354,36 +375,37 @@ async def lmdeploy_model_if_cache( and so on. chat_template (str): needed when model is a pytorch model on huggingface.co, such as "internlm-chat-7b", - "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, and when the model name of local path did not match the original model name in HF. tp (int): tensor parallel prompt (Union[str, List[str]]): input texts to be completed. do_preprocess (bool): whether pre-process the messages. Default to True, which means chat_template will be applied. skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be False. - do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. + in the decoding. Default to be True. + do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. Default to be False, which means greedy decoding will be applied. 
""" try: import lmdeploy from lmdeploy import version_info, GenerationConfig - except: + except Exception: raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") - + kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) - tp = kwargs.pop('tp', 1) - skip_special_tokens = kwargs.pop('skip_special_tokens', False) - do_preprocess = kwargs.pop('do_preprocess', True) - do_sample = kwargs.pop('do_sample', False) + tp = kwargs.pop("tp", 1) + skip_special_tokens = kwargs.pop("skip_special_tokens", True) + do_preprocess = kwargs.pop("do_preprocess", True) + do_sample = kwargs.pop("do_sample", False) gen_params = kwargs - + version = version_info if do_sample is not None and version < (0, 6, 0): raise RuntimeError( - '`do_sample` parameter is not supported by lmdeploy until ' - f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}') + "`do_sample` parameter is not supported by lmdeploy until " + f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" + ) else: do_sample = True gen_params.update(do_sample=do_sample) @@ -394,7 +416,8 @@ async def lmdeploy_model_if_cache( chat_template=chat_template, model_format=model_format, quant_policy=quant_policy, - log_level='WARNING') + log_level="WARNING", + ) messages = [] if system_prompt: @@ -410,11 +433,19 @@ async def lmdeploy_model_if_cache( return if_cache_return["return"] gen_config = GenerationConfig( - skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params) + skip_special_tokens=skip_special_tokens, + max_new_tokens=max_new_tokens, + **gen_params, + ) response = "" - async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config, - do_preprocess=do_preprocess, stream_response=False, session_id=1): + async for res in lmdeploy_pipe.generate( + messages, + gen_config=gen_config, + do_preprocess=do_preprocess, + stream_response=False, + session_id=1, + ): response += res.response if hashing_kv is not None: From 6fe468b4f45ea46e8e7fe0d5034baf083b155467 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:59:40 +0800 Subject: [PATCH 34/35] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d11b1691..bfdf920f 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ - +

From b2021216e90644e515e8717fc32493b0bd17b54b Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:08:41 +0800 Subject: [PATCH 35/35] Update README.md --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bfdf920f..15696b57 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...]) ```python # Incremental Insert: Insert new documents into an existing LightRAG instance -rag = LightRAG(working_dir="./dickens") +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), +) with open("./newText.txt") as f: rag.insert(f.read())