From f59b399656f282aabe140c1e58c7110b827b749e Mon Sep 17 00:00:00 2001
From: zhangjiawei
Date: Wed, 16 Oct 2024 18:10:28 +0800
Subject: [PATCH 01/35] setup encoding modify
---
setup.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/setup.py b/setup.py
index 849fabfe..47222420 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
import setuptools
-with open("README.md", "r") as fh:
+with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
From e5ab24bad45a8ed8c71ff2a2834a14d326b5f8e9 Mon Sep 17 00:00:00 2001
From: Soumil
Date: Mon, 21 Oct 2024 18:34:43 +0100
Subject: [PATCH 02/35] added a class to use multiple models
---
lightrag/llm.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 69 insertions(+)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..d820766d 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -13,6 +13,8 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
@@ -423,6 +425,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
+class Model(BaseModel):
+ """
+ This is a Pydantic model class named 'Model' that is used to define a custom language model.
+
+ Attributes:
+ gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
+ The function should take any argument and return a string.
+ kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+ This could include parameters such as the model name, API key, etc.
+
+ Example usage:
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+ In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
+ The 'kwargs' dictionary contains the model name and API key to be passed to the function.
+ """
+
+ gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
+ kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
+
+ class Config:
+ arbitrary_types_allowed = True
+
+
+class MultiModel():
+ """
+ Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
+ Could also be used for spliting across diffrent models or providers.
+
+ Attributes:
+ models (List[Model]): A list of language models to be used.
+
+ Usage example:
+ ```python
+ models = [
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+ ]
+ multi_model = MultiModel(models)
+ rag = LightRAG(
+ llm_model_func=multi_model.llm_model_func
+ / ..other args
+ )
+ ```
+ """
+ def __init__(self, models: List[Model]):
+ self._models = models
+ self._current_model = 0
+
+ def _next_model(self):
+ self._current_model = (self._current_model + 1) % len(self._models)
+ return self._models[self._current_model]
+
+ async def llm_model_func(
+ self,
+ prompt, system_prompt=None, history_messages=[], **kwargs
+ ) -> str:
+ kwargs.pop("model", None) # stop from overwriting the custom model name
+ next_model = self._next_model()
+ args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
+
+ return await next_model.gen_func(
+ **args
+ )
if __name__ == "__main__":
import asyncio
From 1dd927eb9d5fd8c8af0391a0c3de3e0e2fc1e2b0 Mon Sep 17 00:00:00 2001
From: Abyl Ikhsanov
Date: Mon, 21 Oct 2024 20:40:49 +0200
Subject: [PATCH 03/35] Update llm.py
---
lightrag/llm.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 80 insertions(+), 1 deletion(-)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..51c48b84 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,7 +4,7 @@ import json
import aioboto3
import numpy as np
import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
+from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
from tenacity import (
retry,
stop_after_attempt,
@@ -61,6 +61,49 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_complete_if_cache(model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url=None,
+ api_key=None,
+ **kwargs):
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+ openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ if prompt is not None:
+ messages.append({"role": "user", "content": prompt})
+ if hashing_kv is not None:
+ args_hash = compute_args_hash(model, messages)
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
+ if if_cache_return is not None:
+ return if_cache_return["return"]
+
+ response = await openai_async_client.chat.completions.create(
+ model=model, messages=messages, **kwargs
+ )
+
+ if hashing_kv is not None:
+ await hashing_kv.upsert(
+ {args_hash: {"return": response.choices[0].message.content, "model": model}}
+ )
+ return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -270,6 +313,16 @@ async def gpt_4o_mini_complete(
**kwargs,
)
+async def azure_openai_complete(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await azure_openai_complete_if_cache(
+ "conversation-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
@@ -332,6 +385,32 @@ async def openai_embedding(
)
return np.array([dp.embedding for dp in response.data])
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_embedding(
+ texts: list[str],
+ model: str = "text-embedding-3-small",
+ base_url: str = None,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+ openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+
+ response = await openai_async_client.embeddings.create(
+ model=model, input=texts, encoding_format="float"
+ )
+ return np.array([dp.embedding for dp in response.data])
+
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
From 1ef973c7fc0d9126954dfdc31ecd53643b6c18cd Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Tue, 22 Oct 2024 15:16:57 +0800
Subject: [PATCH 04/35] feat(examples): support siliconcloud free API
---
README.md | 1 +
examples/lightrag_siliconcloud_demo.py | 79 ++++++++++++++++++++++++++
lightrag/llm.py | 48 +++++++++++++++-
requirements.txt | 3 +-
4 files changed, 129 insertions(+), 2 deletions(-)
create mode 100644 examples/lightrag_siliconcloud_demo.py
diff --git a/README.md b/README.md
index 76535d19..87335f1f 100644
--- a/README.md
+++ b/README.md
@@ -629,6 +629,7 @@ def extract_queries(file_path):
│ ├── lightrag_ollama_demo.py
│ ├── lightrag_openai_compatible_demo.py
│ ├── lightrag_openai_demo.py
+│ ├── lightrag_siliconcloud_demo.py
│ └── vram_management_demo.py
├── lightrag
│ ├── __init__.py
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
new file mode 100644
index 00000000..e3f5e67e
--- /dev/null
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -0,0 +1,79 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ "Qwen/Qwen2.5-7B-Instruct",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key=os.getenv("UPSTAGE_API_KEY"),
+ base_url="https://api.siliconflow.cn/v1/",
+ **kwargs,
+ )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await siliconcloud_embedding(
+ texts,
+ model="netease-youdao/bce-embedding-base_v1",
+ api_key=os.getenv("UPSTAGE_API_KEY"),
+ max_token_size=int(512 * 1.5)
+ )
+
+
+# function test
+async def test_funcs():
+ result = await llm_model_func("How are you?")
+ print("llm_model_func: ", result)
+
+ result = await embedding_func(["How are you?"])
+ print("embedding_func: ", result)
+
+
+asyncio.run(test_funcs())
+
+
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=768, max_token_size=512, func=embedding_func
+ ),
+)
+
+
+with open("./book.txt") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..06d75d01 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -2,8 +2,11 @@ import os
import copy
import json
import aioboto3
+import aiohttp
import numpy as np
import ollama
+import base64
+import struct
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
retry,
@@ -312,7 +315,7 @@ async def ollama_model_complete(
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
@@ -332,6 +335,49 @@ async def openai_embedding(
)
return np.array([dp.embedding for dp in response.data])
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def siliconcloud_embedding(
+ texts: list[str],
+ model: str = "netease-youdao/bce-embedding-base_v1",
+ base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+ max_token_size: int = 512,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key and not api_key.startswith('Bearer '):
+ api_key = 'Bearer ' + api_key
+
+ headers = {
+ "Authorization": api_key,
+ "Content-Type": "application/json"
+ }
+
+ truncate_texts = [text[0:max_token_size] for text in texts]
+
+ payload = {
+ "model": model,
+ "input": truncate_texts,
+ "encoding_format": "base64"
+ }
+
+ base64_strings = []
+ async with aiohttp.ClientSession() as session:
+ async with session.post(base_url, headers=headers, json=payload) as response:
+ content = await response.json()
+ if 'code' in content:
+ raise ValueError(content)
+ base64_strings = [item['embedding'] for item in content['data']]
+
+ embeddings = []
+ for string in base64_strings:
+ decode_bytes = base64.b64decode(string)
+ n = len(decode_bytes) // 4
+ float_array = struct.unpack('<' + 'f' * n, decode_bytes)
+ embeddings.append(float_array)
+ return np.array(embeddings)
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
diff --git a/requirements.txt b/requirements.txt
index 9cc5b7e9..5b3396fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,5 @@ tiktoken
torch
transformers
xxhash
-pyvis
\ No newline at end of file
+pyvis
+aiohttp
\ No newline at end of file
From a097b481704bd89adb82c001aff46eec1e8945c7 Mon Sep 17 00:00:00 2001
From: zhangjiawei
Date: Tue, 22 Oct 2024 16:01:40 +0800
Subject: [PATCH 05/35] set encoding as utf-8 when reading ./book.txt in
examples
---
examples/lightrag_hf_demo.py | 2 +-
examples/lightrag_ollama_demo.py | 2 +-
examples/lightrag_openai_compatible_demo.py | 2 +-
examples/lightrag_openai_demo.py | 2 +-
4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py
index 87312307..91033e50 100644
--- a/examples/lightrag_hf_demo.py
+++ b/examples/lightrag_hf_demo.py
@@ -30,7 +30,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index c61b71c0..98f1521c 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -21,7 +21,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index fbad1190..aae56821 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -55,7 +55,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index a6e7f3b2..29bc75ca 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -15,7 +15,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
From ed59b1e1439d6989d2de914b33714d22cfb05970 Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Wed, 23 Oct 2024 11:24:52 +0800
Subject: [PATCH 06/35] Update lightrag_siliconcloud_demo.py
---
examples/lightrag_siliconcloud_demo.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index e3f5e67e..8be6ae7a 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts,
model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("UPSTAGE_API_KEY"),
- max_token_size=int(512 * 1.5)
+ max_token_size=512
)
From 43d7759dcb21be5bff65d6982069b74042617347 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Wed, 23 Oct 2024 11:50:29 +0800
Subject: [PATCH 07/35] Update base.py
---
lightrag/base.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/lightrag/base.py b/lightrag/base.py
index 50be4f62..cecd5edd 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -18,9 +18,13 @@ class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
+ # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
+ # Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
+ # Number of tokens for the relationship descriptions
max_token_for_global_context: int = 4000
+ # Number of tokens for the entity descriptions
max_token_for_local_context: int = 4000
From 4d62129fbaec083d0dde4d3292c5eb5e575907db Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Wed, 23 Oct 2024 11:53:43 +0800
Subject: [PATCH 08/35] Update README.md
---
README.md | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/README.md b/README.md
index 87335f1f..42a7d5db 100644
--- a/README.md
+++ b/README.md
@@ -203,6 +203,21 @@ ollama create -f Modelfile qwen2m
```
+### Query Param
+```python
+class QueryParam:
+ mode: Literal["local", "global", "hybrid", "naive"] = "global"
+ only_need_context: bool = False
+ response_type: str = "Multiple Paragraphs"
+ # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+ top_k: int = 60
+ # Number of tokens for the original chunks.
+ max_token_for_text_unit: int = 4000
+ # Number of tokens for the relationship descriptions
+ max_token_for_global_context: int = 4000
+ # Number of tokens for the entity descriptions
+ max_token_for_local_context: int = 4000
+```
### Batch Insert
```python
From 7d5da418ffee202905c44c6d115c320eb20480e2 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Wed, 23 Oct 2024 11:54:22 +0800
Subject: [PATCH 09/35] Update README.md
---
README.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/README.md b/README.md
index 42a7d5db..41cb4362 100644
--- a/README.md
+++ b/README.md
@@ -204,6 +204,7 @@ ollama create -f Modelfile qwen2m
### Query Param
+
```python
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
@@ -220,6 +221,7 @@ class QueryParam:
```
### Batch Insert
+
```python
# Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
From ec9acd6824004a01eba3a0e9c2448ea75ea7c996 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Wed, 23 Oct 2024 12:15:23 +0800
Subject: [PATCH 10/35] Update README.md
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 41cb4362..dbabcb56 100644
--- a/README.md
+++ b/README.md
@@ -203,6 +203,7 @@ ollama create -f Modelfile qwen2m
```
+
### Query Param
```python
From dfec83de1db29d485383b7397a9d3077863648f1 Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Wed, 23 Oct 2024 15:02:28 +0800
Subject: [PATCH 11/35] fix hf bug
---
examples/lightrag_siliconcloud_demo.py | 4 ++--
lightrag/llm.py | 12 ++++++++++--
2 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index 8be6ae7a..82cab228 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -19,7 +19,7 @@ async def llm_model_func(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
- api_key=os.getenv("UPSTAGE_API_KEY"),
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
base_url="https://api.siliconflow.cn/v1/",
**kwargs,
)
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
return await siliconcloud_embedding(
texts,
model="netease-youdao/bce-embedding-base_v1",
- api_key=os.getenv("UPSTAGE_API_KEY"),
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
max_token_size=512
)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 67f547ea..76adec26 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,5 +1,6 @@
import os
import copy
+from functools import lru_cache
import json
import aioboto3
import aiohttp
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
return response["output"]["message"]["content"][0]["text"]
+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+
+ return hf_model, hf_tokenizer
+
+
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+ hf_model, hf_tokenizer = initialize_hf_model(model_name)
if hf_tokenizer.pad_token is None:
# print("use eos token")
hf_tokenizer.pad_token = hf_tokenizer.eos_token
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
From f96ffad62f652c1fa6b99e39473fa788a70f25bd Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Wed, 23 Oct 2024 15:25:46 +0800
Subject: [PATCH 12/35] move_code
---
lightrag/llm.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 76adec26..4dcf535c 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -207,6 +207,8 @@ async def bedrock_complete_if_cache(
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+ if hf_tokenizer.pad_token is None:
+ hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
@@ -216,9 +218,6 @@ async def hf_model_if_cache(
) -> str:
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
- if hf_tokenizer.pad_token is None:
- # print("use eos token")
- hf_tokenizer.pad_token = hf_tokenizer.eos_token
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
From bd160013fc3053066f503f27fd26ab60a6ef2d95 Mon Sep 17 00:00:00 2001
From: Zhenyu Pan <120090196@link.cuhk.edu.cn>
Date: Thu, 24 Oct 2024 00:58:52 +0800
Subject: [PATCH 13/35] [hotfix-#75][embedding] Fix the potential embedding
problem
---
examples/lightrag_openai_compatible_demo.py | 70 +++++++++++++--------
1 file changed, 43 insertions(+), 27 deletions(-)
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index aae56821..25d3722c 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
+async def get_embedding_dim():
+ test_text = ["This is a test sentence."]
+ embedding = await embedding_func(test_text)
+ embedding_dim = embedding.shape[1]
+ return embedding_dim
+
+
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -43,37 +50,46 @@ async def test_funcs():
print("embedding_func: ", result)
-asyncio.run(test_funcs())
+# asyncio.run(test_funcs())
+
+async def main():
+ try:
+ embedding_dimension = await get_embedding_dim()
+ print(f"Detected embedding dimension: {embedding_dimension}")
+
+ rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func
+ ),
+ )
-rag = LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=4096, max_token_size=8192, func=embedding_func
- ),
-)
+ with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+ # Perform naive search
+ print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+ )
-with open("./book.txt", "r", encoding="utf-8") as f:
- rag.insert(f.read())
+ # Perform local search
+ print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+ )
-# Perform naive search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
-)
+ # Perform global search
+ print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+ )
-# Perform local search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
-)
+ # Perform hybrid search
+ print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+ )
+ except Exception as e:
+ print(f"An error occurred: {e}")
-# Perform global search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
-)
-
-# Perform hybrid search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+if __name__ == "__main__":
+ asyncio.run(main())
\ No newline at end of file
From 060bb1cc5998e23fd93da42ab28847e00d17b73e Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Fri, 25 Oct 2024 14:14:36 +0800
Subject: [PATCH 14/35] Update lightrag.py
---
lightrag/lightrag.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5137af42..b84e22ef 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -208,7 +208,7 @@ class LightRAG:
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
- knwoledge_graph_inst=self.chunk_entity_relation_graph,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
From 88f2c49b8627262155416ded908ac388ce121852 Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Fri, 25 Oct 2024 14:15:31 +0800
Subject: [PATCH 15/35] Update operate.py
---
lightrag/operate.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/lightrag/operate.py b/lightrag/operate.py
index a0729cd8..b90a1ca1 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
- already_node = await knwoledge_graph_inst.get_node(entity_name)
+ already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
From 2401e21ef20f4290c555ff3ae358e01045dbffa6 Mon Sep 17 00:00:00 2001
From: Sanketh Kumar
Date: Fri, 25 Oct 2024 13:23:08 +0530
Subject: [PATCH 16/35] Added linting actions for pull request
---
.github/workflows/linting.yaml | 30 ++++++++++++++++++++++++++++++
1 file changed, 30 insertions(+)
create mode 100644 .github/workflows/linting.yaml
diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml
new file mode 100644
index 00000000..32886cb0
--- /dev/null
+++ b/.github/workflows/linting.yaml
@@ -0,0 +1,30 @@
+name: Linting and Formatting
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ lint-and-format:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pre-commit
+
+ - name: Run pre-commit
+ run: pre-commit run --all-files
\ No newline at end of file
From a157e8e0a24599a296291072beb90fe96c445e6f Mon Sep 17 00:00:00 2001
From: Sanketh Kumar
Date: Fri, 25 Oct 2024 13:32:25 +0530
Subject: [PATCH 17/35] Manually reformatted files
---
.github/workflows/linting.yaml | 4 +-
.gitignore | 2 +-
README.md | 12 +--
examples/graph_visual_with_html.py | 6 +-
examples/graph_visual_with_neo4j.py | 30 +++---
examples/lightrag_openai_compatible_demo.py | 27 ++++--
examples/lightrag_siliconcloud_demo.py | 2 +-
examples/vram_management_demo.py | 36 +++++--
lightrag/llm.py | 101 ++++++++++++--------
lightrag/utils.py | 46 +++++----
requirements.txt | 4 +-
11 files changed, 175 insertions(+), 95 deletions(-)
diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml
index 32886cb0..7c12e0a2 100644
--- a/.github/workflows/linting.yaml
+++ b/.github/workflows/linting.yaml
@@ -15,7 +15,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v2
-
+
- name: Set up Python
uses: actions/setup-python@v2
with:
@@ -27,4 +27,4 @@ jobs:
pip install pre-commit
- name: Run pre-commit
- run: pre-commit run --all-files
\ No newline at end of file
+ run: pre-commit run --all-files
diff --git a/.gitignore b/.gitignore
index 5a41ae32..fd4bd830 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@ dickens/
book.txt
lightrag-dev/
.idea/
-dist/
\ No newline at end of file
+dist/
diff --git a/README.md b/README.md
index dbabcb56..abd7ceb9 100644
--- a/README.md
+++ b/README.md
@@ -58,8 +58,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
-# import nest_asyncio
-# nest_asyncio.apply()
+# import nest_asyncio
+# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -157,7 +157,7 @@ rag = LightRAG(
Using Ollama Models
-
+
* If you want to use Ollama models, you only need to set LightRAG as follows:
```python
@@ -328,8 +328,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -382,7 +382,7 @@ def main():
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index b455e6de..e4337a54 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -3,7 +3,7 @@ from pyvis.network import Network
import random
# Load the GraphML file
-G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
net = Network(notebook=True)
@@ -13,7 +13,7 @@ net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
- node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+ node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
-net.show('knowledge_graph.html')
\ No newline at end of file
+net.show("knowledge_graph.html")
diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py
index 22dde368..7377f21c 100644
--- a/examples/graph_visual_with_neo4j.py
+++ b/examples/graph_visual_with_neo4j.py
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
+
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path)
if json_data:
- with open(output_path, 'w', encoding='utf-8') as f:
+ with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data")
return None
+
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
- batch = data[i:i + batch_size]
+ batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
def main():
# Paths
- xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
- json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+ xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
+ json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return
# Load nodes and edges
- nodes = json_data.get('nodes', [])
- edges = json_data.get('edges', [])
+ nodes = json_data.get("nodes", [])
+ edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
- session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+ session.execute_write(
+ process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
+ )
# Insert edges in batches
- session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+ session.execute_write(
+ process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
+ )
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
+
if __name__ == "__main__":
main()
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index 25d3722c..2470fc00 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -52,6 +52,7 @@ async def test_funcs():
# asyncio.run(test_funcs())
+
async def main():
try:
embedding_dimension = await get_embedding_dim()
@@ -61,35 +62,47 @@ async def main():
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
- embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
),
)
-
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
)
# Perform local search
print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
)
# Perform global search
print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+ rag.query(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="global"),
+ )
)
# Perform hybrid search
print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+ rag.query(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="hybrid"),
+ )
)
except Exception as e:
print(f"An error occurred: {e}")
+
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index 82cab228..a73f16c5 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts,
model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("SILICONFLOW_API_KEY"),
- max_token_size=512
+ max_token_size=512,
)
diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
index ec750254..c173b913 100644
--- a/examples/vram_management_demo.py
+++ b/examples/vram_management_demo.py
@@ -27,11 +27,12 @@ rag = LightRAG(
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
- if filename.endswith('.txt'):
+ if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read())
+
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts)
return
except Exception as e:
- print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
+ print(
+ f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
+ )
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
+
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
+ )
except Exception as e:
print(f"Error performing naive search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
+ )
except Exception as e:
print(f"Error performing local search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="global")
+ )
+ )
except Exception as e:
print(f"Error performing global search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+ )
+ )
except Exception as e:
print(f"Error performing hybrid search: {e}")
+
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
+
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 4dcf535c..eaaa2b75 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -7,7 +7,13 @@ import aiohttp
import numpy as np
import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ Timeout,
+ AsyncAzureOpenAI,
+)
import base64
import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
+
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
-async def azure_openai_complete_if_cache(model,
+async def azure_openai_complete_if_cache(
+ model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
- **kwargs):
+ **kwargs,
+):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
)
return response.choices[0].message.content
+
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+ hf_tokenizer = AutoTokenizer.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -328,8 +344,9 @@ async def gpt_4o_mini_complete(
**kwargs,
)
+
async def azure_openai_complete(
- prompt, system_prompt=None, history_messages=[], **kwargs
+ prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
@@ -339,6 +356,7 @@ async def azure_openai_complete(
**kwargs,
)
+
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -418,9 +436,11 @@ async def azure_openai_embedding(
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
@@ -440,35 +460,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
- if api_key and not api_key.startswith('Bearer '):
- api_key = 'Bearer ' + api_key
+ if api_key and not api_key.startswith("Bearer "):
+ api_key = "Bearer " + api_key
- headers = {
- "Authorization": api_key,
- "Content-Type": "application/json"
- }
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
- payload = {
- "model": model,
- "input": truncate_texts,
- "encoding_format": "base64"
- }
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
- if 'code' in content:
+ if "code" in content:
raise ValueError(content)
- base64_strings = [item['embedding'] for item in content['data']]
-
+ base64_strings = [item["embedding"] for item in content["data"]]
+
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
- float_array = struct.unpack('<' + 'f' * n, decode_bytes)
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
@@ -563,6 +576,7 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
+
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +594,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
- gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
- kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
+ gen_func: Callable[[Any], str] = Field(
+ ...,
+ description="A function that generates the response from the llm. The response must be a string",
+ )
+ kwargs: Dict[str, Any] = Field(
+ ...,
+ description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
+ )
class Config:
arbitrary_types_allowed = True
-class MultiModel():
+class MultiModel:
"""
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for spliting across diffrent models or providers.
@@ -611,26 +631,31 @@ class MultiModel():
)
```
"""
+
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
-
+
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
- self,
- prompt, system_prompt=None, history_messages=[], **kwargs
+ self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
- kwargs.pop("model", None) # stop from overwriting the custom model name
+ kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
- args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
-
- return await next_model.gen_func(
- **args
+ args = dict(
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ **next_model.kwargs,
)
+ return await next_model.gen_func(**args)
+
+
if __name__ == "__main__":
import asyncio
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9a68c16b..0da4a51a 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
+
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
- data = {
- "nodes": [],
- "edges": []
- }
+ data = {"nodes": [], "edges": []}
# Use namespace
- namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+ namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
- for node in root.findall('.//node', namespace):
+ for node in root.findall(".//node", namespace):
node_data = {
- "id": node.get('id').strip('"'),
- "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
- "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
- "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+ "id": node.get("id").strip('"'),
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
+ if node.find("./data[@key='d0']", namespace) is not None
+ else "",
+ "description": node.find("./data[@key='d1']", namespace).text
+ if node.find("./data[@key='d1']", namespace) is not None
+ else "",
+ "source_id": node.find("./data[@key='d2']", namespace).text
+ if node.find("./data[@key='d2']", namespace) is not None
+ else "",
}
data["nodes"].append(node_data)
- for edge in root.findall('.//edge', namespace):
+ for edge in root.findall(".//edge", namespace):
edge_data = {
- "source": edge.get('source').strip('"'),
- "target": edge.get('target').strip('"'),
- "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
- "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
- "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
- "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+ "source": edge.get("source").strip('"'),
+ "target": edge.get("target").strip('"'),
+ "weight": float(edge.find("./data[@key='d3']", namespace).text)
+ if edge.find("./data[@key='d3']", namespace) is not None
+ else 0.0,
+ "description": edge.find("./data[@key='d4']", namespace).text
+ if edge.find("./data[@key='d4']", namespace) is not None
+ else "",
+ "keywords": edge.find("./data[@key='d5']", namespace).text
+ if edge.find("./data[@key='d5']", namespace) is not None
+ else "",
+ "source_id": edge.find("./data[@key='d6']", namespace).text
+ if edge.find("./data[@key='d6']", namespace) is not None
+ else "",
}
data["edges"].append(edge_data)
diff --git a/requirements.txt b/requirements.txt
index 5b3396fb..98f32b0a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,15 @@
accelerate
aioboto3
+aiohttp
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
+pyvis
tenacity
tiktoken
torch
transformers
xxhash
-pyvis
-aiohttp
\ No newline at end of file
From 6df870712eb0cc661f6e8d9fe55bdf35de9c670b Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Fri, 25 Oct 2024 19:25:26 +0800
Subject: [PATCH 18/35] fix Step_3_openai_compatible.py
---
reproduce/Step_3_openai_compatible.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py
index 2be5ea5c..5e2ef778 100644
--- a/reproduce/Step_3_openai_compatible.py
+++ b/reproduce/Step_3_openai_compatible.py
@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
- result, context = await rag_instance.aquery(query_text, param=query_param)
- return {"query": query_text, "result": result, "context": context}, None
+ result = await rag_instance.aquery(query_text, param=query_param)
+ return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}
From 3325d97fb7cf46faf6cf499e4c92936e0f1ab0e3 Mon Sep 17 00:00:00 2001
From: jatin009v
Date: Fri, 25 Oct 2024 18:39:55 +0530
Subject: [PATCH 19/35] Key Enhancements: Error Handling:
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Handled potential FileNotFoundError for README.md and requirements.txt.
Checked for missing required metadata and raised an informative error if any are missing.
Automated Package Discovery:
Replaced packages=["lightrag"] with setuptools.find_packages() to automatically find sub-packages and exclude test or documentation directories.
Additional Metadata:
Added Development Status in classifiers to indicate a "Beta" release (modify based on the project's maturity).
Used project_urls to link documentation, source code, and an issue tracker, which are standard for open-source projects.
Compatibility:
Included include_package_data=True to include additional files specified in MANIFEST.in.
These changes enhance the readability, reliability, and openness of the code, making it more contributor-friendly and ensuring it’s ready for open-source distribution.
---
setup.py | 74 ++++++++++++++++++++++++++++++++++++++++----------------
1 file changed, 53 insertions(+), 21 deletions(-)
diff --git a/setup.py b/setup.py
index 47222420..bdf49f02 100644
--- a/setup.py
+++ b/setup.py
@@ -1,39 +1,71 @@
import setuptools
+from pathlib import Path
-with open("README.md", "r", encoding="utf-8") as fh:
- long_description = fh.read()
+# Reading the long description from README.md
+def read_long_description():
+ try:
+ return Path("README.md").read_text(encoding="utf-8")
+ except FileNotFoundError:
+ return "A description of LightRAG is currently unavailable."
+# Retrieving metadata from __init__.py
+def retrieve_metadata():
+ vars2find = ["__author__", "__version__", "__url__"]
+ vars2readme = {}
+ try:
+ with open("./lightrag/__init__.py") as f:
+ for line in f.readlines():
+ for v in vars2find:
+ if line.startswith(v):
+ line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
+ vars2readme[v] = line.split("=")[1]
+ except FileNotFoundError:
+ raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
+
+ # Checking if all required variables are found
+ missing_vars = [v for v in vars2find if v not in vars2readme]
+ if missing_vars:
+ raise ValueError(f"Missing required metadata variables in __init__.py: {missing_vars}")
+
+ return vars2readme
-vars2find = ["__author__", "__version__", "__url__"]
-vars2readme = {}
-with open("./lightrag/__init__.py") as f:
- for line in f.readlines():
- for v in vars2find:
- if line.startswith(v):
- line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
- vars2readme[v] = line.split("=")[1]
+# Reading dependencies from requirements.txt
+def read_requirements():
+ deps = []
+ try:
+ with open("./requirements.txt") as f:
+ deps = [line.strip() for line in f if line.strip()]
+ except FileNotFoundError:
+ print("Warning: 'requirements.txt' not found. No dependencies will be installed.")
+ return deps
-deps = []
-with open("./requirements.txt") as f:
- for line in f.readlines():
- if not line.strip():
- continue
- deps.append(line.strip())
+metadata = retrieve_metadata()
+long_description = read_long_description()
+requirements = read_requirements()
setuptools.setup(
name="lightrag-hku",
- url=vars2readme["__url__"],
- version=vars2readme["__version__"],
- author=vars2readme["__author__"],
+ url=metadata["__url__"],
+ version=metadata["__version__"],
+ author=metadata["__author__"],
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description,
long_description_content_type="text/markdown",
- packages=["lightrag"],
+ packages=setuptools.find_packages(exclude=("tests*", "docs*")), # Automatically find packages
classifiers=[
+ "Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
+ "Intended Audience :: Developers",
+ "Topic :: Software Development :: Libraries :: Python Modules",
],
python_requires=">=3.9",
- install_requires=deps,
+ install_requires=requirements,
+ include_package_data=True, # Includes non-code files from MANIFEST.in
+ project_urls={ # Additional project metadata
+ "Documentation": metadata.get("__url__", ""),
+ "Source": metadata.get("__url__", ""),
+ "Tracker": f"{metadata.get('__url__', '')}/issues" if metadata.get("__url__") else ""
+ },
)
From af1a7f66fa703ba42904fcbe70d3ef7fff317bbf Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 00:37:03 +0800
Subject: [PATCH 20/35] add Algorithm Flowchart
---
README.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/README.md b/README.md
index dbabcb56..f2f5c20e 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,10 @@ This repository hosts the code of LightRAG. The structure of this code is based
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+## Algorithm Flowchart
+
+
+
## Install
* Install from source (Recommend)
From aad2e1f2d609e7626e29da32e699e647930e8493 Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 00:37:46 +0800
Subject: [PATCH 21/35] add Algorithm FLowchart
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index f2f5c20e..0f8659b1 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
## Algorithm Flowchart
+
## Install
From 226f6f3d87febd3041017d3d0299a00138ce8832 Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Sat, 26 Oct 2024 02:20:23 +0800
Subject: [PATCH 22/35] fix hf output bug
---
lightrag/llm.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 4dcf535c..692937fb 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -266,10 +266,11 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
+ inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
)
- response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
+ response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
From 87f8b7dba1a334459b7401a6b88e68fd7e0ecc33 Mon Sep 17 00:00:00 2001
From: tackhwa <55059307+tackhwa@users.noreply.github.com>
Date: Sat, 26 Oct 2024 02:42:40 +0800
Subject: [PATCH 23/35] Update token length
---
lightrag/llm.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 692937fb..ab459fc7 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -268,7 +268,7 @@ async def hf_model_if_cache(
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
- **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+ **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
if hashing_kv is not None:
From eec29d041aaf314d69ff4eae7661dcdc9a21b107 Mon Sep 17 00:00:00 2001
From: Yazington
Date: Sat, 26 Oct 2024 00:11:21 -0400
Subject: [PATCH 24/35] fixing bug
---
lightrag/lightrag.py | 6 ++++--
lightrag/operate.py | 26 +++++++++++++-------------
2 files changed, 17 insertions(+), 15 deletions(-)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5137af42..3004f5ed 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -85,7 +85,9 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
- llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+ llm_model_name: str = (
+ "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+ )
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
@@ -208,7 +210,7 @@ class LightRAG:
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
- knwoledge_graph_inst=self.chunk_entity_relation_graph,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
diff --git a/lightrag/operate.py b/lightrag/operate.py
index a0729cd8..8a6820f5 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
- already_node = await knwoledge_graph_inst.get_node(entity_name)
+ already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description,
source_id=source_id,
)
- await knwoledge_graph_inst.upsert_node(
+ await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = []
already_keywords = []
- if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
- already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
+ if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+ already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
- if not (await knwoledge_graph_inst.has_node(need_insert_id)):
- await knwoledge_graph_inst.upsert_node(
+ if not (await knowledge_graph_inst.has_node(need_insert_id)):
+ await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
- await knwoledge_graph_inst.upsert_edge(
+ await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
- _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
- _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
+ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
@@ -384,7 +384,7 @@ async def extract_entities(
}
await relationships_vdb.upsert(data_for_vdb)
- return knwoledge_graph_inst
+ return knowledge_graph_inst
async def local_query(
From e07b9f05306e2b21b1acfe47b3c47118bf658de9 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Sat, 26 Oct 2024 14:04:11 +0800
Subject: [PATCH 25/35] Update graph_visual_with_html.py
---
examples/graph_visual_with_html.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index e4337a54..11279b3a 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -6,7 +6,7 @@ import random
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
-net = Network(notebook=True)
+net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
From ea3e13b522afdb7533d1b1098b5d2461f98c1ad6 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Sat, 26 Oct 2024 14:40:17 +0800
Subject: [PATCH 26/35] update version
---
lightrag/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index db81e005..8e76a260 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "0.0.7"
+__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
From dba7dad6dff4c588b6fd6c0d59830f910df18650 Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 15:56:48 +0800
Subject: [PATCH 27/35] [feat] Add API server implementation and endpoints
---
README.md | 119 ++++++++++++++
.../lightrag_api_openai_compatible_demo.py | 153 ++++++++++++++++++
2 files changed, 272 insertions(+)
create mode 100644 examples/lightrag_api_openai_compatible_demo.py
diff --git a/README.md b/README.md
index 7fab9a01..d11b1691 100644
--- a/README.md
+++ b/README.md
@@ -397,6 +397,125 @@ if __name__ == "__main__":
+## API Server Implementation
+
+LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
+
+### Setting up the API Server
+
+Click to expand setup instructions
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+
+Click to view Query endpoint details
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "query": "Your question here",
+ "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+ -H "Content-Type: application/json" \
+ -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+
+
+#### 2. Insert Text Endpoint
+
+Click to view Insert Text endpoint details
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+ -H "Content-Type: application/json" \
+ -d '{"text": "Content to be inserted into RAG"}'
+```
+
+
+#### 3. Insert File Endpoint
+
+Click to view Insert File endpoint details
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+ -H "Content-Type: application/json" \
+ -d '{"file_path": "./book.txt"}'
+```
+
+
+#### 4. Health Check Endpoint
+
+Click to view Health Check endpoint details
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+
+Click to view error handling details
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Supports multiple file encodings (UTF-8 and GBK)
+
+
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
new file mode 100644
index 00000000..f8d105ea
--- /dev/null
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -0,0 +1,153 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+from typing import Optional
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR="index_default"
+app = FastAPI(title="LightRAG API", description="API for RAG operations")
+
+# Configure working directory
+WORKING_DIR = os.environ.get('RAG_DIR', f'{DEFAULT_RAG_DIR}')
+print(f"WORKING_DIR: {WORKING_DIR}")
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+# LLM model function
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ "gpt-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key='YOUR_API_KEY',
+ base_url="YourURL/v1",
+ **kwargs,
+ )
+
+# Embedding function
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await openai_embedding(
+ texts,
+ model="text-embedding-3-large",
+ api_key='YOUR_API_KEY',
+ base_url="YourURL/v1",
+ )
+
+# Initialize RAG instance
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=3072, max_token_size=8192, func=embedding_func
+ ),
+)
+
+# Data models
+class QueryRequest(BaseModel):
+ query: str
+ mode: str = "hybrid"
+
+class InsertRequest(BaseModel):
+ text: str
+
+class InsertFileRequest(BaseModel):
+ file_path: str
+
+class Response(BaseModel):
+ status: str
+ data: Optional[str] = None
+ message: Optional[str] = None
+
+# API routes
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ result = await loop.run_in_executor(
+ None,
+ lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+ )
+ return Response(
+ status="success",
+ data=result
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(request.text))
+ return Response(
+ status="success",
+ message="Text inserted successfully"
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/insert_file", response_model=Response)
+async def insert_file(request: InsertFileRequest):
+ try:
+ # Check if file exists
+ if not os.path.exists(request.file_path):
+ raise HTTPException(
+ status_code=404,
+ detail=f"File not found: {request.file_path}"
+ )
+
+ # Read file content
+ try:
+ with open(request.file_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ except UnicodeDecodeError:
+ # If UTF-8 decoding fails, try other encodings
+ with open(request.file_path, 'r', encoding='gbk') as f:
+ content = f.read()
+
+ # Insert file content
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(content))
+
+ return Response(
+ status="success",
+ message=f"File content from {request.file_path} inserted successfully"
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/health")
+async def health_check():
+ return {"status": "healthy"}
+
+if __name__ == "__main__":
+ import uvicorn
+ uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_openai_compatible_demo.py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file:
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+
+# 4. Health check:
+# curl -X GET "http://127.0.0.1:8020/health"
\ No newline at end of file
From 3d7b05b3e7d78435929ef5f1d878b9968ce5519d Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 16:00:30 +0800
Subject: [PATCH 28/35] Refactor code formatting in
lightrag_api_openai_compatible_demo.py
---
.../lightrag_api_openai_compatible_demo.py | 29 ++++++++++++++-----
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index f8d105ea..ad9560dc 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -12,7 +12,7 @@ import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
-DEFAULT_RAG_DIR="index_default"
+DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
@@ -22,6 +22,8 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
+
+
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -36,6 +38,8 @@ async def llm_model_func(
)
# Embedding function
+
+
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
@@ -54,29 +58,37 @@ rag = LightRAG(
)
# Data models
+
+
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
+
class InsertRequest(BaseModel):
text: str
+
class InsertFileRequest(BaseModel):
file_path: str
+
class Response(BaseModel):
status: str
data: Optional[str] = None
message: Optional[str] = None
# API routes
+
+
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
- None,
- lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+ None,
+ lambda: rag.query(
+ request.query, param=QueryParam(mode=request.mode))
)
return Response(
status="success",
@@ -85,6 +97,7 @@ async def query_endpoint(request: QueryRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
@@ -97,6 +110,7 @@ async def insert_endpoint(request: InsertRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@app.post("/insert_file", response_model=Response)
async def insert_file(request: InsertFileRequest):
try:
@@ -106,7 +120,7 @@ async def insert_file(request: InsertFileRequest):
status_code=404,
detail=f"File not found: {request.file_path}"
)
-
+
# Read file content
try:
with open(request.file_path, 'r', encoding='utf-8') as f:
@@ -115,11 +129,11 @@ async def insert_file(request: InsertFileRequest):
# If UTF-8 decoding fails, try other encodings
with open(request.file_path, 'r', encoding='gbk') as f:
content = f.read()
-
+
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
-
+
return Response(
status="success",
message=f"File content from {request.file_path} inserted successfully"
@@ -127,6 +141,7 @@ async def insert_file(request: InsertFileRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@@ -150,4 +165,4 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
-# curl -X GET "http://127.0.0.1:8020/health"
\ No newline at end of file
+# curl -X GET "http://127.0.0.1:8020/health"
From 981be9e569af127e755612eaf94c07c19b193f2f Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 16:09:36 +0800
Subject: [PATCH 29/35] Refactor code formatting in
lightrag_api_openai_compatible_demo.py
---
.../lightrag_api_openai_compatible_demo.py | 34 ++++++++-----------
1 file changed, 15 insertions(+), 19 deletions(-)
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index ad9560dc..2cd262bb 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -16,7 +16,7 @@ DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
-WORKING_DIR = os.environ.get('RAG_DIR', f'{DEFAULT_RAG_DIR}')
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -32,11 +32,12 @@ async def llm_model_func(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
- api_key='YOUR_API_KEY',
+ api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs,
)
+
# Embedding function
@@ -44,10 +45,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="text-embedding-3-large",
- api_key='YOUR_API_KEY',
+ api_key="YOUR_API_KEY",
base_url="YourURL/v1",
)
+
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
@@ -78,6 +80,7 @@ class Response(BaseModel):
data: Optional[str] = None
message: Optional[str] = None
+
# API routes
@@ -86,14 +89,9 @@ async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
- None,
- lambda: rag.query(
- request.query, param=QueryParam(mode=request.mode))
- )
- return Response(
- status="success",
- data=result
+ None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
)
+ return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -103,10 +101,7 @@ async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
- return Response(
- status="success",
- message="Text inserted successfully"
- )
+ return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -117,17 +112,16 @@ async def insert_file(request: InsertFileRequest):
# Check if file exists
if not os.path.exists(request.file_path):
raise HTTPException(
- status_code=404,
- detail=f"File not found: {request.file_path}"
+ status_code=404, detail=f"File not found: {request.file_path}"
)
# Read file content
try:
- with open(request.file_path, 'r', encoding='utf-8') as f:
+ with open(request.file_path, "r", encoding="utf-8") as f:
content = f.read()
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
- with open(request.file_path, 'r', encoding='gbk') as f:
+ with open(request.file_path, "r", encoding="gbk") as f:
content = f.read()
# Insert file content
@@ -136,7 +130,7 @@ async def insert_file(request: InsertFileRequest):
return Response(
status="success",
- message=f"File content from {request.file_path} inserted successfully"
+ message=f"File content from {request.file_path} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -146,8 +140,10 @@ async def insert_file(request: InsertFileRequest):
async def health_check():
return {"status": "healthy"}
+
if __name__ == "__main__":
import uvicorn
+
uvicorn.run(app, host="0.0.0.0", port=8020)
# Usage example
From 8deb30aa205de458b727a24345f8b4a511dabafc Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Sat, 26 Oct 2024 16:11:15 +0800
Subject: [PATCH 30/35] support lmdeploy backend
---
examples/lightrag_lmdeploy_demo.py | 74 +++++++++++++++++++++
lightrag/llm.py | 100 +++++++++++++++++++++++++++++
requirements.txt | 1 +
3 files changed, 175 insertions(+)
create mode 100644 examples/lightrag_lmdeploy_demo.py
diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py
new file mode 100644
index 00000000..ea7ace0e
--- /dev/null
+++ b/examples/lightrag_lmdeploy_demo.py
@@ -0,0 +1,74 @@
+import os
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
+from lightrag.utils import EmbeddingFunc
+from transformers import AutoModel, AutoTokenizer
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+async def lmdeploy_model_complete(
+ prompt=None, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ return await lmdeploy_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ ## please specify chat_template if your local path does not follow original HF file name,
+ ## or model_name is a pytorch model on huggingface.co,
+ ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
+ ## for a list of chat_template available in lmdeploy.
+ chat_template = "llama3",
+ # model_format ='awq', # if you are using awq quantization model.
+ # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
+ **kwargs,
+ )
+
+
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=lmdeploy_model_complete,
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
+ embedding_func=EmbeddingFunc(
+ embedding_dim=384,
+ max_token_size=5000,
+ func=lambda texts: hf_embedding(
+ texts,
+ tokenizer=AutoTokenizer.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ embed_model=AutoModel.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ ),
+ ),
+)
+
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index bb0d6063..028084bd 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -322,6 +322,106 @@ async def ollama_model_if_cache(
return result
+@lru_cache(maxsize=1)
+def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0):
+ from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+ lmdeploy_pipe = pipeline(
+ model_path=model,
+ backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy),
+ chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None,
+ log_level='WARNING')
+ return lmdeploy_pipe
+
+
+async def lmdeploy_model_if_cache(
+ model, prompt, system_prompt=None, history_messages=[],
+ chat_template=None, model_format='hf',quant_policy=0, **kwargs
+) -> str:
+ """
+ Args:
+ model (str): The path to the model.
+ It could be one of the following options:
+ - i) A local directory path of a turbomind model which is
+ converted by `lmdeploy convert` command or download
+ from ii) and iii).
+ - ii) The model_id of a lmdeploy-quantized model hosted
+ inside a model repo on huggingface.co, such as
+ "InternLM/internlm-chat-20b-4bit",
+ "lmdeploy/llama2-chat-70b-4bit", etc.
+ - iii) The model_id of a model hosted inside a model repo
+ on huggingface.co, such as "internlm/internlm-chat-7b",
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+ and so on.
+ chat_template (str): needed when model is a pytorch model on
+ huggingface.co, such as "internlm-chat-7b",
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
+ and when the model name of local path did not match the original model name in HF.
+ tp (int): tensor parallel
+ prompt (Union[str, List[str]]): input texts to be completed.
+ do_preprocess (bool): whether pre-process the messages. Default to
+ True, which means chat_template will be applied.
+ skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Default to be False.
+ do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
+ Default to be False, which means greedy decoding will be applied.
+ """
+ try:
+ import lmdeploy
+ from lmdeploy import version_info, GenerationConfig
+ except:
+ raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
+
+ kwargs.pop("response_format", None)
+ max_new_tokens = kwargs.pop("max_tokens", 512)
+ tp = kwargs.pop('tp', 1)
+ skip_special_tokens = kwargs.pop('skip_special_tokens', False)
+ do_preprocess = kwargs.pop('do_preprocess', True)
+ do_sample = kwargs.pop('do_sample', False)
+ gen_params = kwargs
+
+ version = version_info
+ if do_sample is not None and version < (0, 6, 0):
+ raise RuntimeError(
+ '`do_sample` parameter is not supported by lmdeploy until '
+ f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}')
+ else:
+ do_sample = True
+ gen_params.update(do_sample=do_sample)
+
+ lmdeploy_pipe = initialize_lmdeploy_pipeline(
+ model=model,
+ tp=tp,
+ chat_template=chat_template,
+ model_format=model_format,
+ quant_policy=quant_policy,
+ log_level='WARNING')
+
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ if hashing_kv is not None:
+ args_hash = compute_args_hash(model, messages)
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
+ if if_cache_return is not None:
+ return if_cache_return["return"]
+
+ gen_config = GenerationConfig(
+ skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params)
+
+ response = ""
+ async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config,
+ do_preprocess=do_preprocess, stream_response=False, session_id=1):
+ response += res.response
+
+ if hashing_kv is not None:
+ await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+ return response
+
+
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
diff --git a/requirements.txt b/requirements.txt
index 98f32b0a..6b0e025a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,4 @@ tiktoken
torch
transformers
xxhash
+# lmdeploy[all]
From 883d6b7cc7fd1df29717ab5858f194cfbf0e3246 Mon Sep 17 00:00:00 2001
From: "zhenjie.ye"
Date: Sat, 26 Oct 2024 16:12:10 +0800
Subject: [PATCH 31/35] Refactor code formatting in
lightrag_api_openai_compatible_demo.py
---
lightrag/lightrag.py | 4 +---
lightrag/llm.py | 4 +++-
setup.py | 31 ++++++++++++++++++++++++-------
3 files changed, 28 insertions(+), 11 deletions(-)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 3004f5ed..b84e22ef 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -85,9 +85,7 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
- llm_model_name: str = (
- "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
- )
+ llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
diff --git a/lightrag/llm.py b/lightrag/llm.py
index bb0d6063..fd6b72d6 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -286,7 +286,9 @@ async def hf_model_if_cache(
output = hf_model.generate(
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
- response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+ response_text = hf_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+ )
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
diff --git a/setup.py b/setup.py
index bdf49f02..1b1f65f0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
import setuptools
from pathlib import Path
+
# Reading the long description from README.md
def read_long_description():
try:
@@ -8,6 +9,7 @@ def read_long_description():
except FileNotFoundError:
return "A description of LightRAG is currently unavailable."
+
# Retrieving metadata from __init__.py
def retrieve_metadata():
vars2find = ["__author__", "__version__", "__url__"]
@@ -17,18 +19,26 @@ def retrieve_metadata():
for line in f.readlines():
for v in vars2find:
if line.startswith(v):
- line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
+ line = (
+ line.replace(" ", "")
+ .replace('"', "")
+ .replace("'", "")
+ .strip()
+ )
vars2readme[v] = line.split("=")[1]
except FileNotFoundError:
raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
-
+
# Checking if all required variables are found
missing_vars = [v for v in vars2find if v not in vars2readme]
if missing_vars:
- raise ValueError(f"Missing required metadata variables in __init__.py: {missing_vars}")
-
+ raise ValueError(
+ f"Missing required metadata variables in __init__.py: {missing_vars}"
+ )
+
return vars2readme
+
# Reading dependencies from requirements.txt
def read_requirements():
deps = []
@@ -36,9 +46,12 @@ def read_requirements():
with open("./requirements.txt") as f:
deps = [line.strip() for line in f if line.strip()]
except FileNotFoundError:
- print("Warning: 'requirements.txt' not found. No dependencies will be installed.")
+ print(
+ "Warning: 'requirements.txt' not found. No dependencies will be installed."
+ )
return deps
+
metadata = retrieve_metadata()
long_description = read_long_description()
requirements = read_requirements()
@@ -51,7 +64,9 @@ setuptools.setup(
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description,
long_description_content_type="text/markdown",
- packages=setuptools.find_packages(exclude=("tests*", "docs*")), # Automatically find packages
+ packages=setuptools.find_packages(
+ exclude=("tests*", "docs*")
+ ), # Automatically find packages
classifiers=[
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
@@ -66,6 +81,8 @@ setuptools.setup(
project_urls={ # Additional project metadata
"Documentation": metadata.get("__url__", ""),
"Source": metadata.get("__url__", ""),
- "Tracker": f"{metadata.get('__url__', '')}/issues" if metadata.get("__url__") else ""
+ "Tracker": f"{metadata.get('__url__', '')}/issues"
+ if metadata.get("__url__")
+ else "",
},
)
From 2e703296d5e9f4a15547c1d1be3ecb53eab1925c Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Sat, 26 Oct 2024 16:13:18 +0800
Subject: [PATCH 32/35] pre-commit
---
examples/lightrag_lmdeploy_demo.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py
index ea7ace0e..aeb96f71 100644
--- a/examples/lightrag_lmdeploy_demo.py
+++ b/examples/lightrag_lmdeploy_demo.py
@@ -10,10 +10,11 @@ WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
+
async def lmdeploy_model_complete(
prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await lmdeploy_model_if_cache(
model_name,
prompt,
@@ -23,7 +24,7 @@ async def lmdeploy_model_complete(
## or model_name is a pytorch model on huggingface.co,
## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
## for a list of chat_template available in lmdeploy.
- chat_template = "llama3",
+ chat_template="llama3",
# model_format ='awq', # if you are using awq quantization model.
# quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
**kwargs,
@@ -33,7 +34,7 @@ async def lmdeploy_model_complete(
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=lmdeploy_model_complete,
- llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
From 2cf3a85a0f09094372ae632cfce95cf1f649de76 Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Sat, 26 Oct 2024 16:24:35 +0800
Subject: [PATCH 33/35] update do_preprocess
---
lightrag/llm.py | 77 ++++++++++++++++++++++++++++++++++---------------
1 file changed, 54 insertions(+), 23 deletions(-)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 028084bd..d86886ea 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -286,7 +286,9 @@ async def hf_model_if_cache(
output = hf_model.generate(
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
- response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+ response_text = hf_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+ )
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
@@ -323,19 +325,38 @@ async def ollama_model_if_cache(
@lru_cache(maxsize=1)
-def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0):
+def initialize_lmdeploy_pipeline(
+ model,
+ tp=1,
+ chat_template=None,
+ log_level="WARNING",
+ model_format="hf",
+ quant_policy=0,
+):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
lmdeploy_pipe = pipeline(
model_path=model,
- backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy),
- chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None,
- log_level='WARNING')
+ backend_config=TurbomindEngineConfig(
+ tp=tp, model_format=model_format, quant_policy=quant_policy
+ ),
+ chat_template_config=ChatTemplateConfig(model_name=chat_template)
+ if chat_template
+ else None,
+ log_level="WARNING",
+ )
return lmdeploy_pipe
async def lmdeploy_model_if_cache(
- model, prompt, system_prompt=None, history_messages=[],
- chat_template=None, model_format='hf',quant_policy=0, **kwargs
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ chat_template=None,
+ model_format="hf",
+ quant_policy=0,
+ **kwargs,
) -> str:
"""
Args:
@@ -354,36 +375,37 @@ async def lmdeploy_model_if_cache(
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
- "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
- in the decoding. Default to be False.
- do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
+ in the decoding. Default to be True.
+ do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
- except:
+ except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
-
+
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
- tp = kwargs.pop('tp', 1)
- skip_special_tokens = kwargs.pop('skip_special_tokens', False)
- do_preprocess = kwargs.pop('do_preprocess', True)
- do_sample = kwargs.pop('do_sample', False)
+ tp = kwargs.pop("tp", 1)
+ skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+ do_preprocess = kwargs.pop("do_preprocess", True)
+ do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
-
+
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
- '`do_sample` parameter is not supported by lmdeploy until '
- f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}')
+ "`do_sample` parameter is not supported by lmdeploy until "
+ f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
+ )
else:
do_sample = True
gen_params.update(do_sample=do_sample)
@@ -394,7 +416,8 @@ async def lmdeploy_model_if_cache(
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
- log_level='WARNING')
+ log_level="WARNING",
+ )
messages = []
if system_prompt:
@@ -410,11 +433,19 @@ async def lmdeploy_model_if_cache(
return if_cache_return["return"]
gen_config = GenerationConfig(
- skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params)
+ skip_special_tokens=skip_special_tokens,
+ max_new_tokens=max_new_tokens,
+ **gen_params,
+ )
response = ""
- async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config,
- do_preprocess=do_preprocess, stream_response=False, session_id=1):
+ async for res in lmdeploy_pipe.generate(
+ messages,
+ gen_config=gen_config,
+ do_preprocess=do_preprocess,
+ stream_response=False,
+ session_id=1,
+ ):
response += res.response
if hashing_kv is not None:
From 6d84569703213a006570dd68346d4399967e18f0 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Mon, 28 Oct 2024 09:59:40 +0800
Subject: [PATCH 34/35] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index d11b1691..bfdf920f 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
-
+
From fee8575750cfdc7b03765048949536218a863531 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Mon, 28 Oct 2024 15:08:41 +0800
Subject: [PATCH 35/35] Update README.md
---
README.md | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index bfdf920f..15696b57 100644
--- a/README.md
+++ b/README.md
@@ -237,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...])
```python
# Incremental Insert: Insert new documents into an existing LightRAG instance
-rag = LightRAG(working_dir="./dickens")
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
+ ),
+)
with open("./newText.txt") as f:
rag.insert(f.read())