diff --git a/.gitignore b/.gitignore
index ec95f8a5..0bb7ea15 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,5 @@ venv/
examples/input/
examples/output/
.DS_Store
+# Keep config.ini (and other local .ini files) out of the repo
+*.ini
\ No newline at end of file
diff --git a/README.md b/README.md
index eeecf734..4ef38fbe 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
```python
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -177,7 +177,7 @@ async def llm_model_func(
)
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
@@ -233,7 +233,7 @@ If you want to use Ollama models, you need to pull model you plan to use and emb
Then you only need to set LightRAG as follows:
```python
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# Initialize LightRAG with Ollama model
@@ -245,7 +245,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
+ func=lambda texts: ollama_embed(
texts,
embed_model="nomic-embed-text"
)
@@ -690,7 +690,7 @@ if __name__ == "__main__":
| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
-| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
+| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embed` |
| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
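For quick reference, this is the import migration the diff applies across the README and examples: the monolithic `lightrag.llm` module is split into per-provider submodules, and the `*_embedding` helpers are renamed to `*_embed`. A minimal sketch using only names that appear in this diff:

```python
# Before: everything was imported from the monolithic lightrag.llm module
# from lightrag.llm import gpt_4o_mini_complete, openai_embedding, ollama_embedding

# After: one submodule per provider, and *_embedding helpers become *_embed
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
```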
diff --git a/config.ini b/config.ini
deleted file mode 100644
index eb1d4e88..00000000
--- a/config.ini
+++ /dev/null
@@ -1,13 +0,0 @@
-[redis]
-uri = redis://localhost:6379
-
-[neo4j]
-uri = #
-username = neo4j
-password = 12345678
-
-[milvus]
-uri = #
-user = root
-password = Milvus
-db_name = lightrag
diff --git a/examples/insert_custom_kg.py b/examples/insert_custom_kg.py
index 1c02ea25..50ad925e 100644
--- a/examples/insert_custom_kg.py
+++ b/examples/insert_custom_kg.py
@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
diff --git a/examples/lightrag_api_ollama_demo.py b/examples/lightrag_api_ollama_demo.py
index 36df1262..634264d3 100644
--- a/examples/lightrag_api_ollama_demo.py
+++ b/examples/lightrag_api_ollama_demo.py
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
from typing import Optional
import asyncio
@@ -38,7 +38,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
+ func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
diff --git a/examples/lightrag_api_open_webui_demo.py b/examples/lightrag_api_open_webui_demo.py
index 17e1817e..88454da8 100644
--- a/examples/lightrag_api_open_webui_demo.py
+++ b/examples/lightrag_api_open_webui_demo.py
@@ -9,7 +9,7 @@ from typing import Optional
import os
import logging
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
import nest_asyncio
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index 1749753b..8173dc5b 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
@@ -48,7 +48,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model=EMBEDDING_MODEL,
)
diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py
index 65ac4ddd..602ca900 100644
--- a/examples/lightrag_api_oracle_demo.py
+++ b/examples/lightrag_api_oracle_demo.py
@@ -13,7 +13,7 @@ from pathlib import Path
import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -64,7 +64,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model=EMBEDDING_MODEL,
api_key=APIKEY,
diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py
index 7e18ea57..6bb6c7d4 100644
--- a/examples/lightrag_bedrock_demo.py
+++ b/examples/lightrag_bedrock_demo.py
@@ -6,7 +6,7 @@ import os
import logging
from lightrag import LightRAG, QueryParam
-from lightrag.llm import bedrock_complete, bedrock_embedding
+from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
from lightrag.utils import EmbeddingFunc
logging.getLogger("aiobotocore").setLevel(logging.WARNING)
@@ -20,7 +20,7 @@ rag = LightRAG(
llm_model_func=bedrock_complete,
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
embedding_func=EmbeddingFunc(
- embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+ embedding_dim=1024, max_token_size=8192, func=bedrock_embed
),
)
diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py
index 91033e50..a5088e54 100644
--- a/examples/lightrag_hf_demo.py
+++ b/examples/lightrag_hf_demo.py
@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm.hf import hf_model_complete, hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
@@ -17,7 +17,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
- func=lambda texts: hf_embedding(
+ func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
diff --git a/examples/lightrag_jinaai_demo.py b/examples/lightrag_jinaai_demo.py
index 4daead75..0378b61b 100644
--- a/examples/lightrag_jinaai_demo.py
+++ b/examples/lightrag_jinaai_demo.py
@@ -1,13 +1,14 @@
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
-from lightrag.llm import jina_embedding, openai_complete_if_cache
+from lightrag.llm.jina import jina_embed
+from lightrag.llm.openai import openai_complete_if_cache
import os
import asyncio
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await jina_embedding(texts, api_key="YourJinaAPIKey")
+ return await jina_embed(texts, api_key="YourJinaAPIKey")
WORKING_DIR = "./dickens"
diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py
index c0ee86cb..d12eb564 100644
--- a/examples/lightrag_lmdeploy_demo.py
+++ b/examples/lightrag_lmdeploy_demo.py
@@ -1,7 +1,8 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
+from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
+from lightrag.llm.hf import hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
@@ -42,7 +43,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
- func=lambda texts: hf_embedding(
+ func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
diff --git a/examples/lightrag_nvidia_demo.py b/examples/lightrag_nvidia_demo.py
index 5af562b0..da4b46ff 100644
--- a/examples/lightrag_nvidia_demo.py
+++ b/examples/lightrag_nvidia_demo.py
@@ -3,7 +3,7 @@ import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
openai_complete_if_cache,
- nvidia_openai_embedding,
+ nvidia_openai_embed,
)
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -47,7 +47,7 @@ nvidia_embed_model = "nvidia/nv-embedqa-e5-v5"
async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
- return await nvidia_openai_embedding(
+ return await nvidia_openai_embed(
texts,
model=nvidia_embed_model, # maximum 512 token
# model="nvidia/llama-3.2-nv-embedqa-1b-v1",
@@ -60,7 +60,7 @@ async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
async def query_embedding_func(texts: list[str]) -> np.ndarray:
- return await nvidia_openai_embedding(
+ return await nvidia_openai_embed(
texts,
model=nvidia_embed_model, # maximum 512 token
# model="nvidia/llama-3.2-nv-embedqa-1b-v1",
diff --git a/examples/lightrag_ollama_age_demo.py b/examples/lightrag_ollama_age_demo.py
index 403843a7..d394ded4 100644
--- a/examples/lightrag_ollama_age_demo.py
+++ b/examples/lightrag_ollama_age_demo.py
@@ -4,7 +4,7 @@ import logging
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens_age"
@@ -32,7 +32,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
+ func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index 162900c4..95856fa2 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -3,7 +3,7 @@ import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
@@ -23,7 +23,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
+ func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
diff --git a/examples/lightrag_ollama_gremlin_demo.py b/examples/lightrag_ollama_gremlin_demo.py
index 35ffece8..fa7d4fb5 100644
--- a/examples/lightrag_ollama_gremlin_demo.py
+++ b/examples/lightrag_ollama_gremlin_demo.py
@@ -10,7 +10,7 @@ import os
# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens_gremlin"
@@ -41,7 +41,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
+ func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
diff --git a/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py b/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py
index 8d26ba65..b71489c7 100644
--- a/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py
+++ b/examples/lightrag_ollama_neo4j_milvus_mongo_demo.py
@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# WorkingDir
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index 3494ae03..09673dd8 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
diff --git a/examples/lightrag_openai_compatible_demo_embedding_cache.py b/examples/lightrag_openai_compatible_demo_embedding_cache.py
index 69106d05..d696ce25 100644
--- a/examples/lightrag_openai_compatible_demo_embedding_cache.py
+++ b/examples/lightrag_openai_compatible_demo_embedding_cache.py
@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py
index 9345ada5..93c4297c 100644
--- a/examples/lightrag_openai_compatible_stream_demo.py
+++ b/examples/lightrag_openai_compatible_stream_demo.py
@@ -1,7 +1,7 @@
import os
import inspect
from lightrag import LightRAG
-from lightrag.llm import openai_complete, openai_embedding
+from lightrag.llm.openai import openai_complete, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag import QueryParam
@@ -24,7 +24,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
- func=lambda texts: openai_embedding(
+ func=lambda texts: openai_embed(
texts=texts,
model="text-embedding-bge-m3",
base_url="http://127.0.0.1:1234/v1",
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index 29bc75ca..7a43a710 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
WORKING_DIR = "./dickens"
diff --git a/examples/lightrag_openai_neo4j_milvus_redis_demo.py b/examples/lightrag_openai_neo4j_milvus_redis_demo.py
index 3de1a657..75e110aa 100644
--- a/examples/lightrag_openai_neo4j_milvus_redis_demo.py
+++ b/examples/lightrag_openai_neo4j_milvus_redis_demo.py
@@ -1,6 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embed, openai_complete_if_cache
+from lightrag.llm.ollama import ollama_embed
+from lightrag.llm.openai import openai_complete_if_cache
from lightrag.utils import EmbeddingFunc
# WorkingDir
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index 6de6e0a7..47020fd6 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -3,7 +3,7 @@ import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
@@ -42,7 +42,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model=EMBEDMODEL,
api_key=APIKEY,
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index ca15ccd0..5f4f86a1 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -1,7 +1,8 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
+from lightrag.llm.openai import openai_complete_if_cache
+from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
diff --git a/examples/lightrag_zhipu_demo.py b/examples/lightrag_zhipu_demo.py
index 0924656d..97a5042e 100644
--- a/examples/lightrag_zhipu_demo.py
+++ b/examples/lightrag_zhipu_demo.py
@@ -3,7 +3,7 @@ import logging
from lightrag import LightRAG, QueryParam
-from lightrag.llm import zhipu_complete, zhipu_embedding
+from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py
index d0461d84..4ed88602 100644
--- a/examples/lightrag_zhipu_postgres_demo.py
+++ b/examples/lightrag_zhipu_postgres_demo.py
@@ -6,7 +6,8 @@ from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB
-from lightrag.llm import ollama_embedding, zhipu_complete
+from lightrag.llm.ollama import ollama_embed as ollama_embedding
+from lightrag.llm.zhipu import zhipu_complete
from lightrag.utils import EmbeddingFunc
load_dotenv()
diff --git a/examples/test.py b/examples/test.py
index 80bcaa6d..67ee22eb 100644
--- a/examples/test.py
+++ b/examples/test.py
@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py
index df721bb2..0e6361ed 100644
--- a/examples/test_chromadb.py
+++ b/examples/test_chromadb.py
@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, openai_embedding
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -35,7 +35,7 @@ EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model=EMBEDDING_MODEL,
)
diff --git a/examples/test_neo4j.py b/examples/test_neo4j.py
index 0048fc17..ac5f7fb7 100644
--- a/examples/test_neo4j.py
+++ b/examples/test_neo4j.py
@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
#########
diff --git a/examples/test_split_by_character.ipynb b/examples/test_split_by_character.ipynb
index e8e08b92..df5d938d 100644
--- a/examples/test_split_by_character.ipynb
+++ b/examples/test_split_by_character.ipynb
@@ -16,7 +16,7 @@
"import logging\n",
"import numpy as np\n",
"from lightrag import LightRAG, QueryParam\n",
- "from lightrag.llm import openai_complete_if_cache, openai_embedding\n",
+ "from lightrag.llm.openai import openai_complete_if_cache, openai_embed\n",
"from lightrag.utils import EmbeddingFunc\n",
"import nest_asyncio"
]
@@ -74,7 +74,7 @@
"\n",
"\n",
"async def embedding_func(texts: list[str]) -> np.ndarray:\n",
- " return await openai_embedding(\n",
+ " return await openai_embed(\n",
" texts,\n",
" model=\"ep-20241231173413-pgjmk\",\n",
" api_key=API,\n",
@@ -138,7 +138,7 @@
"\n",
"\n",
"async def embedding_func(texts: list[str]) -> np.ndarray:\n",
- " return await openai_embedding(\n",
+ " return await openai_embed(\n",
" texts,\n",
" model=\"ep-20241231173413-pgjmk\",\n",
" api_key=API,\n",
diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
index c173b913..b8d0872e 100644
--- a/examples/vram_management_demo.py
+++ b/examples/vram_management_demo.py
@@ -1,7 +1,7 @@
import os
import time
from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# Working directory and the directory path for text files
@@ -20,7 +20,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
+ func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
),
)
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index e505c1c5..92bd51ec 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -8,10 +8,6 @@ import time
import re
from typing import List, Dict, Any, Optional, Union
from lightrag import LightRAG, QueryParam
-from lightrag.llm import lollms_model_complete, lollms_embed
-from lightrag.llm import ollama_model_complete, ollama_embed
-from lightrag.llm import openai_complete_if_cache, openai_embedding
-from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
@@ -720,6 +716,20 @@ def create_app(args):
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
+ if args.llm_binding_host == "lollms" or args.embedding_binding == "lollms":
+ from lightrag.llm.lollms import lollms_model_complete, lollms_embed
+ if args.llm_binding_host == "ollama" or args.embedding_binding == "ollama":
+ from lightrag.llm.ollama import ollama_model_complete, ollama_embed
+ if args.llm_binding_host == "openai" or args.embedding_binding == "openai":
+ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
+ if (
+ args.llm_binding_host == "azure_openai"
+ or args.embedding_binding == "azure_openai"
+ ):
+ from lightrag.llm.azure_openai import (
+ azure_openai_complete_if_cache,
+ azure_openai_embed,
+ )
async def openai_alike_model_complete(
prompt,
@@ -773,13 +783,13 @@ def create_app(args):
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "ollama"
- else azure_openai_embedding(
+ else azure_openai_embed(
texts,
model=args.embedding_model, # no host is used for openai,
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "azure_openai"
- else openai_embedding(
+ else openai_embed(
texts,
model=args.embedding_model, # no host is used for openai,
api_key=args.embedding_binding_api_key,
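The server now imports each provider lazily, only when its binding is selected. A condensed sketch of that dispatch pattern; `select_embed_func` is a hypothetical helper, and only `args` attributes that appear in this diff are referenced:

```python
def select_embed_func(args):
    """Return an async embedding callable for the configured binding (hypothetical helper)."""
    if args.embedding_binding == "ollama":
        from lightrag.llm.ollama import ollama_embed
        return lambda texts: ollama_embed(texts, embed_model=args.embedding_model)
    if args.embedding_binding == "azure_openai":
        from lightrag.llm.azure_openai import azure_openai_embed
        return lambda texts: azure_openai_embed(
            texts, model=args.embedding_model, api_key=args.embedding_binding_api_key
        )
    # Default to the plain OpenAI-compatible embedding endpoint.
    from lightrag.llm.openai import openai_embed
    return lambda texts: openai_embed(
        texts, model=args.embedding_model, api_key=args.embedding_binding_api_key
    )
```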
diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt
index 74776828..fc5afd58 100644
--- a/lightrag/api/requirements.txt
+++ b/lightrag/api/requirements.txt
@@ -1,4 +1,3 @@
-aioboto3
ascii_colors
fastapi
nano_vectordb
diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
new file mode 100644
index 00000000..249ba7e2
--- /dev/null
+++ b/lightrag/exceptions.py
@@ -0,0 +1,51 @@
+import httpx
+from typing import Literal
+
+class APIStatusError(Exception):
+ """Raised when an API response has a status code of 4xx or 5xx."""
+
+ response: httpx.Response
+ status_code: int
+ request_id: str | None
+
+ def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
+ super().__init__(message)
+ self.response = response
+ self.status_code = response.status_code
+ self.request_id = response.headers.get("x-request-id")
+
+class APIConnectionError(Exception):
+ def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
+ super().__init__(message)
+
+
+class BadRequestError(APIStatusError):
+ status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class AuthenticationError(APIStatusError):
+ status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class PermissionDeniedError(APIStatusError):
+ status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class NotFoundError(APIStatusError):
+ status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class ConflictError(APIStatusError):
+ status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class UnprocessableEntityError(APIStatusError):
+ status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class RateLimitError(APIStatusError):
+ status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
+
+class APITimeoutError(APIConnectionError):
+ def __init__(self, request: httpx.Request) -> None:
+ super().__init__(message="Request timed out.", request=request)
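A sketch of how a caller might surface HTTP failures with these classes; the status-to-class mapping and the helper name are illustrative, not part of this diff:

```python
import httpx

from lightrag.exceptions import (
    APIStatusError,
    AuthenticationError,
    BadRequestError,
    RateLimitError,
)

# Illustrative mapping; extend with the remaining subclasses as needed.
_STATUS_TO_ERROR = {400: BadRequestError, 401: AuthenticationError, 429: RateLimitError}


def raise_for_status(response: httpx.Response) -> None:
    """Raise the matching exception for a 4xx/5xx response."""
    if response.is_success:
        return
    error_cls = _STATUS_TO_ERROR.get(response.status_code, APIStatusError)
    raise error_cls(
        f"HTTP {response.status_code} returned by {response.request.url}",
        response=response,
        body=None,
    )
```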
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index a861cb26..111fc6c8 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -1,7 +1,8 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
-import aioredis
+# aioredis is a deprecated library, replaced with the redis package's asyncio client
+from redis.asyncio import Redis
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
import json
@@ -11,7 +12,7 @@ import json
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
- self._redis = aioredis.from_url(redis_url, decode_responses=True)
+ self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")
async def all_keys(self) -> list[str]:
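`redis.asyncio` (shipped with the `redis` package) is the maintained replacement for aioredis and keeps the same `from_url` interface. A minimal standalone sketch; the URL and key are illustrative:

```python
import asyncio

from redis.asyncio import Redis  # provided by the `redis` package, replacing aioredis


async def main() -> None:
    # Same connect-from-URL pattern the storage class now uses.
    redis = Redis.from_url("redis://localhost:6379", decode_responses=True)
    await redis.set("lightrag:demo", "value")  # illustrative key
    print(await redis.get("lightrag:demo"))
    await redis.aclose()  # redis-py >= 5.0; older releases use close()


asyncio.run(main())
```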
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index c732d432..69b165cb 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -6,10 +6,6 @@ from datetime import datetime
from functools import partial
from typing import Type, cast, Dict
-from .llm import (
- gpt_4o_mini_complete,
- openai_embedding,
-)
from .operate import (
chunking_by_token_size,
extract_entities,
@@ -154,12 +150,12 @@ class LightRAG:
)
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
- embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
+ embedding_func: EmbeddingFunc = None # must be set explicitly; embedding providers are now decoupled from the core, so there is no default
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
- llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
+ llm_model_func: callable = None # must be set explicitly; LLM providers are now decoupled from the core, so there is no default
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
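With the defaults gone, both `llm_model_func` and `embedding_func` must be passed explicitly when constructing `LightRAG`. A minimal sketch consistent with the README and example updates in this diff (1536 dimensions matches `openai_embed`'s default text-embedding-3-small model; the working directory is illustrative):

```python
from lightrag import LightRAG
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./dickens",              # illustrative path, as used in the examples
    llm_model_func=gpt_4o_mini_complete,  # no longer defaulted by LightRAG
    embedding_func=EmbeddingFunc(         # no longer defaulted by LightRAG
        embedding_dim=1536,
        max_token_size=8192,
        func=openai_embed,
    ),
)
```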
diff --git a/lightrag/llm.py b/lightrag/llm.py
index f1c21dab..3ca17725 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,1211 +1,5 @@
-import base64
-import copy
-import json
-import os
-import re
-import struct
-from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union, Optional
-import aioboto3
-import aiohttp
-import numpy as np
-import ollama
-import torch
-from openai import (
- AsyncOpenAI,
- APIConnectionError,
- RateLimitError,
- APITimeoutError,
- AsyncAzureOpenAI,
-)
+from typing import List, Dict, Callable, Any
from pydantic import BaseModel, Field
-from tenacity import (
- retry,
- stop_after_attempt,
- wait_exponential,
- retry_if_exception_type,
-)
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-from .utils import (
- wrap_embedding_func_with_attrs,
- locate_json_string_body_from_string,
- safe_unicode_decode,
- logger,
-)
-
-import sys
-
-if sys.version_info < (3, 9):
- from typing import AsyncIterator
-else:
- from collections.abc import AsyncIterator
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def openai_complete_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- base_url=None,
- api_key=None,
- **kwargs,
-) -> str:
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
-
- openai_async_client = (
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
- )
- kwargs.pop("hashing_kv", None)
- kwargs.pop("keyword_extraction", None)
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.extend(history_messages)
- messages.append({"role": "user", "content": prompt})
-
- # 添加日志输出
- logger.debug("===== Query Input to LLM =====")
- logger.debug(f"Query: {prompt}")
- logger.debug(f"System prompt: {system_prompt}")
- logger.debug("Full context:")
- if "response_format" in kwargs:
- response = await openai_async_client.beta.chat.completions.parse(
- model=model, messages=messages, **kwargs
- )
- else:
- response = await openai_async_client.chat.completions.create(
- model=model, messages=messages, **kwargs
- )
-
- if hasattr(response, "__aiter__"):
-
- async def inner():
- async for chunk in response:
- content = chunk.choices[0].delta.content
- if content is None:
- continue
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
- yield content
-
- return inner()
- else:
- content = response.choices[0].message.content
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
- return content
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APIConnectionError)
- ),
-)
-async def azure_openai_complete_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- base_url=None,
- api_key=None,
- api_version=None,
- **kwargs,
-):
- if api_key:
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
- if base_url:
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- if api_version:
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
-
- openai_async_client = AsyncAzureOpenAI(
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
- )
- kwargs.pop("hashing_kv", None)
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.extend(history_messages)
- if prompt is not None:
- messages.append({"role": "user", "content": prompt})
-
- if "response_format" in kwargs:
- response = await openai_async_client.beta.chat.completions.parse(
- model=model, messages=messages, **kwargs
- )
- else:
- response = await openai_async_client.chat.completions.create(
- model=model, messages=messages, **kwargs
- )
-
- if hasattr(response, "__aiter__"):
-
- async def inner():
- async for chunk in response:
- if len(chunk.choices) == 0:
- continue
- content = chunk.choices[0].delta.content
- if content is None:
- continue
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
- yield content
-
- return inner()
- else:
- content = response.choices[0].message.content
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
- return content
-
-
-class BedrockError(Exception):
- """Generic error for issues related to Amazon Bedrock"""
-
-
-@retry(
- stop=stop_after_attempt(5),
- wait=wait_exponential(multiplier=1, max=60),
- retry=retry_if_exception_type((BedrockError)),
-)
-async def bedrock_complete_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- aws_access_key_id=None,
- aws_secret_access_key=None,
- aws_session_token=None,
- **kwargs,
-) -> str:
- os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
- "AWS_ACCESS_KEY_ID", aws_access_key_id
- )
- os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
- "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
- )
- os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
- "AWS_SESSION_TOKEN", aws_session_token
- )
- kwargs.pop("hashing_kv", None)
- # Fix message history format
- messages = []
- for history_message in history_messages:
- message = copy.copy(history_message)
- message["content"] = [{"text": message["content"]}]
- messages.append(message)
-
- # Add user prompt
- messages.append({"role": "user", "content": [{"text": prompt}]})
-
- # Initialize Converse API arguments
- args = {"modelId": model, "messages": messages}
-
- # Define system prompt
- if system_prompt:
- args["system"] = [{"text": system_prompt}]
-
- # Map and set up inference parameters
- inference_params_map = {
- "max_tokens": "maxTokens",
- "top_p": "topP",
- "stop_sequences": "stopSequences",
- }
- if inference_params := list(
- set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
- ):
- args["inferenceConfig"] = {}
- for param in inference_params:
- args["inferenceConfig"][inference_params_map.get(param, param)] = (
- kwargs.pop(param)
- )
-
- # Call model via Converse API
- session = aioboto3.Session()
- async with session.client("bedrock-runtime") as bedrock_async_client:
- try:
- response = await bedrock_async_client.converse(**args, **kwargs)
- except Exception as e:
- raise BedrockError(e)
-
- return response["output"]["message"]["content"][0]["text"]
-
-
-@lru_cache(maxsize=1)
-def initialize_hf_model(model_name):
- hf_tokenizer = AutoTokenizer.from_pretrained(
- model_name, device_map="auto", trust_remote_code=True
- )
- hf_model = AutoModelForCausalLM.from_pretrained(
- model_name, device_map="auto", trust_remote_code=True
- )
- if hf_tokenizer.pad_token is None:
- hf_tokenizer.pad_token = hf_tokenizer.eos_token
-
- return hf_model, hf_tokenizer
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def hf_model_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- **kwargs,
-) -> str:
- model_name = model
- hf_model, hf_tokenizer = initialize_hf_model(model_name)
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.extend(history_messages)
- messages.append({"role": "user", "content": prompt})
- kwargs.pop("hashing_kv", None)
- input_prompt = ""
- try:
- input_prompt = hf_tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
- except Exception:
- try:
- ori_message = copy.deepcopy(messages)
- if messages[0]["role"] == "system":
- messages[1]["content"] = (
- ""
- + messages[0]["content"]
- + "\n"
- + messages[1]["content"]
- )
- messages = messages[1:]
- input_prompt = hf_tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
- except Exception:
- len_message = len(ori_message)
- for msgid in range(len_message):
- input_prompt = (
- input_prompt
- + "<"
- + ori_message[msgid]["role"]
- + ">"
- + ori_message[msgid]["content"]
- + ""
- + ori_message[msgid]["role"]
- + ">\n"
- )
-
- input_ids = hf_tokenizer(
- input_prompt, return_tensors="pt", padding=True, truncation=True
- ).to("cuda")
- inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
- output = hf_model.generate(
- **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
- )
- response_text = hf_tokenizer.decode(
- output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
- )
-
- return response_text
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def ollama_model_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- **kwargs,
-) -> Union[str, AsyncIterator[str]]:
- stream = True if kwargs.get("stream") else False
- kwargs.pop("max_tokens", None)
- # kwargs.pop("response_format", None) # allow json
- host = kwargs.pop("host", None)
- timeout = kwargs.pop("timeout", None)
- kwargs.pop("hashing_kv", None)
- api_key = kwargs.pop("api_key", None)
- headers = (
- {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- if api_key
- else {"Content-Type": "application/json"}
- )
- ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.extend(history_messages)
- messages.append({"role": "user", "content": prompt})
-
- response = await ollama_client.chat(model=model, messages=messages, **kwargs)
- if stream:
- """cannot cache stream response"""
-
- async def inner():
- async for chunk in response:
- yield chunk["message"]["content"]
-
- return inner()
- else:
- return response["message"]["content"]
-
-
-async def lollms_model_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- base_url="http://localhost:9600",
- **kwargs,
-) -> Union[str, AsyncIterator[str]]:
- """Client implementation for lollms generation."""
-
- stream = True if kwargs.get("stream") else False
- api_key = kwargs.pop("api_key", None)
- headers = (
- {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- if api_key
- else {"Content-Type": "application/json"}
- )
-
- # Extract lollms specific parameters
- request_data = {
- "prompt": prompt,
- "model_name": model,
- "personality": kwargs.get("personality", -1),
- "n_predict": kwargs.get("n_predict", None),
- "stream": stream,
- "temperature": kwargs.get("temperature", 0.1),
- "top_k": kwargs.get("top_k", 50),
- "top_p": kwargs.get("top_p", 0.95),
- "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
- "repeat_last_n": kwargs.get("repeat_last_n", 40),
- "seed": kwargs.get("seed", None),
- "n_threads": kwargs.get("n_threads", 8),
- }
-
- # Prepare the full prompt including history
- full_prompt = ""
- if system_prompt:
- full_prompt += f"{system_prompt}\n"
- for msg in history_messages:
- full_prompt += f"{msg['role']}: {msg['content']}\n"
- full_prompt += prompt
-
- request_data["prompt"] = full_prompt
- timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
-
- async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
- if stream:
-
- async def inner():
- async with session.post(
- f"{base_url}/lollms_generate", json=request_data
- ) as response:
- async for line in response.content:
- yield line.decode().strip()
-
- return inner()
- else:
- async with session.post(
- f"{base_url}/lollms_generate", json=request_data
- ) as response:
- return await response.text()
-
-
-@lru_cache(maxsize=1)
-def initialize_lmdeploy_pipeline(
- model,
- tp=1,
- chat_template=None,
- log_level="WARNING",
- model_format="hf",
- quant_policy=0,
-):
- from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
-
- lmdeploy_pipe = pipeline(
- model_path=model,
- backend_config=TurbomindEngineConfig(
- tp=tp, model_format=model_format, quant_policy=quant_policy
- ),
- chat_template_config=(
- ChatTemplateConfig(model_name=chat_template) if chat_template else None
- ),
- log_level="WARNING",
- )
- return lmdeploy_pipe
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def lmdeploy_model_if_cache(
- model,
- prompt,
- system_prompt=None,
- history_messages=[],
- chat_template=None,
- model_format="hf",
- quant_policy=0,
- **kwargs,
-) -> str:
- """
- Args:
- model (str): The path to the model.
- It could be one of the following options:
- - i) A local directory path of a turbomind model which is
- converted by `lmdeploy convert` command or download
- from ii) and iii).
- - ii) The model_id of a lmdeploy-quantized model hosted
- inside a model repo on huggingface.co, such as
- "InternLM/internlm-chat-20b-4bit",
- "lmdeploy/llama2-chat-70b-4bit", etc.
- - iii) The model_id of a model hosted inside a model repo
- on huggingface.co, such as "internlm/internlm-chat-7b",
- "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
- and so on.
- chat_template (str): needed when model is a pytorch model on
- huggingface.co, such as "internlm-chat-7b",
- "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
- and when the model name of local path did not match the original model name in HF.
- tp (int): tensor parallel
- prompt (Union[str, List[str]]): input texts to be completed.
- do_preprocess (bool): whether pre-process the messages. Default to
- True, which means chat_template will be applied.
- skip_special_tokens (bool): Whether or not to remove special tokens
- in the decoding. Default to be True.
- do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
- Default to be False, which means greedy decoding will be applied.
- """
- try:
- import lmdeploy
- from lmdeploy import version_info, GenerationConfig
- except Exception:
- raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
- kwargs.pop("hashing_kv", None)
- kwargs.pop("response_format", None)
- max_new_tokens = kwargs.pop("max_tokens", 512)
- tp = kwargs.pop("tp", 1)
- skip_special_tokens = kwargs.pop("skip_special_tokens", True)
- do_preprocess = kwargs.pop("do_preprocess", True)
- do_sample = kwargs.pop("do_sample", False)
- gen_params = kwargs
-
- version = version_info
- if do_sample is not None and version < (0, 6, 0):
- raise RuntimeError(
- "`do_sample` parameter is not supported by lmdeploy until "
- f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
- )
- else:
- do_sample = True
- gen_params.update(do_sample=do_sample)
-
- lmdeploy_pipe = initialize_lmdeploy_pipeline(
- model=model,
- tp=tp,
- chat_template=chat_template,
- model_format=model_format,
- quant_policy=quant_policy,
- log_level="WARNING",
- )
-
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
-
- messages.extend(history_messages)
- messages.append({"role": "user", "content": prompt})
-
- gen_config = GenerationConfig(
- skip_special_tokens=skip_special_tokens,
- max_new_tokens=max_new_tokens,
- **gen_params,
- )
-
- response = ""
- async for res in lmdeploy_pipe.generate(
- messages,
- gen_config=gen_config,
- do_preprocess=do_preprocess,
- stream_response=False,
- session_id=1,
- ):
- response += res.response
- return response
-
-
-class GPTKeywordExtractionFormat(BaseModel):
- high_level_keywords: List[str]
- low_level_keywords: List[str]
-
-
-async def openai_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> Union[str, AsyncIterator[str]]:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = "json"
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
- return await openai_complete_if_cache(
- model_name,
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-async def gpt_4o_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = GPTKeywordExtractionFormat
- return await openai_complete_if_cache(
- "gpt-4o",
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-async def gpt_4o_mini_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["response_format"] = GPTKeywordExtractionFormat
- return await openai_complete_if_cache(
- "gpt-4o-mini",
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-async def nvidia_openai_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- result = await openai_complete_if_cache(
- "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- base_url="https://integrate.api.nvidia.com/v1",
- **kwargs,
- )
- if keyword_extraction: # TODO: use JSON API
- return locate_json_string_body_from_string(result)
- return result
-
-
-async def azure_openai_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- result = await azure_openai_complete_if_cache(
- os.getenv("LLM_MODEL", "gpt-4o-mini"),
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
- if keyword_extraction: # TODO: use JSON API
- return locate_json_string_body_from_string(result)
- return result
-
-
-async def bedrock_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- result = await bedrock_complete_if_cache(
- "anthropic.claude-3-haiku-20240307-v1:0",
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
- if keyword_extraction: # TODO: use JSON API
- return locate_json_string_body_from_string(result)
- return result
-
-
-async def hf_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
- result = await hf_model_if_cache(
- model_name,
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
- if keyword_extraction: # TODO: use JSON API
- return locate_json_string_body_from_string(result)
- return result
-
-
-async def ollama_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> Union[str, AsyncIterator[str]]:
- keyword_extraction = kwargs.pop("keyword_extraction", None)
- if keyword_extraction:
- kwargs["format"] = "json"
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
- return await ollama_model_if_cache(
- model_name,
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-async def lollms_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> Union[str, AsyncIterator[str]]:
- """Complete function for lollms model generation."""
-
- # Extract and remove keyword_extraction from kwargs if present
- keyword_extraction = kwargs.pop("keyword_extraction", None)
-
- # Get model name from config
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-
- # If keyword extraction is needed, we might need to modify the prompt
- # or add specific parameters for JSON output (if lollms supports it)
- if keyword_extraction:
- # Note: You might need to adjust this based on how lollms handles structured output
- pass
-
- return await lollms_model_if_cache(
- model_name,
- prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def zhipu_complete_if_cache(
- prompt: Union[str, List[Dict[str, str]]],
- model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
- api_key: Optional[str] = None,
- system_prompt: Optional[str] = None,
- history_messages: List[Dict[str, str]] = [],
- **kwargs,
-) -> str:
- # dynamically load ZhipuAI
- try:
- from zhipuai import ZhipuAI
- except ImportError:
- raise ImportError("Please install zhipuai before initialize zhipuai backend.")
-
- if api_key:
- client = ZhipuAI(api_key=api_key)
- else:
- # please set ZHIPUAI_API_KEY in your environment
- # os.environ["ZHIPUAI_API_KEY"]
- client = ZhipuAI()
-
- messages = []
-
- if not system_prompt:
- system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
-
- # Add system prompt if provided
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.extend(history_messages)
- messages.append({"role": "user", "content": prompt})
-
- # Add debug logging
- logger.debug("===== Query Input to LLM =====")
- logger.debug(f"Query: {prompt}")
- logger.debug(f"System prompt: {system_prompt}")
-
- # Remove unsupported kwargs
- kwargs = {
- k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
- }
-
- response = client.chat.completions.create(model=model, messages=messages, **kwargs)
-
- return response.choices[0].message.content
-
-
-async def zhipu_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-):
- # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
- keyword_extraction = kwargs.pop("keyword_extraction", None)
-
- if keyword_extraction:
- # Add a system prompt to guide the model to return JSON format
- extraction_prompt = """You are a helpful assistant that extracts keywords from text.
- Please analyze the content and extract two types of keywords:
- 1. High-level keywords: Important concepts and main themes
- 2. Low-level keywords: Specific details and supporting elements
-
- Return your response in this exact JSON format:
- {
- "high_level_keywords": ["keyword1", "keyword2"],
- "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
- }
-
- Only return the JSON, no other text."""
-
- # Combine with existing system prompt if any
- if system_prompt:
- system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
- else:
- system_prompt = extraction_prompt
-
- try:
- response = await zhipu_complete_if_cache(
- prompt=prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
- # Try to parse as JSON
- try:
- data = json.loads(response)
- return GPTKeywordExtractionFormat(
- high_level_keywords=data.get("high_level_keywords", []),
- low_level_keywords=data.get("low_level_keywords", []),
- )
- except json.JSONDecodeError:
- # If direct JSON parsing fails, try to extract JSON from text
- match = re.search(r"\{[\s\S]*\}", response)
- if match:
- try:
- data = json.loads(match.group())
- return GPTKeywordExtractionFormat(
- high_level_keywords=data.get("high_level_keywords", []),
- low_level_keywords=data.get("low_level_keywords", []),
- )
- except json.JSONDecodeError:
- pass
-
- # If all parsing fails, log warning and return empty format
- logger.warning(
- f"Failed to parse keyword extraction response: {response}"
- )
- return GPTKeywordExtractionFormat(
- high_level_keywords=[], low_level_keywords=[]
- )
- except Exception as e:
- logger.error(f"Error during keyword extraction: {str(e)}")
- return GPTKeywordExtractionFormat(
- high_level_keywords=[], low_level_keywords=[]
- )
- else:
- # For non-keyword-extraction, just return the raw response string
- return await zhipu_complete_if_cache(
- prompt=prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- **kwargs,
- )
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=60),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def zhipu_embedding(
- texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
-) -> np.ndarray:
- # dynamically load ZhipuAI
- try:
- from zhipuai import ZhipuAI
- except ImportError:
- raise ImportError("Please install zhipuai before initialize zhipuai backend.")
- if api_key:
- client = ZhipuAI(api_key=api_key)
- else:
- # please set ZHIPUAI_API_KEY in your environment
- # os.environ["ZHIPUAI_API_KEY"]
- client = ZhipuAI()
-
- # Convert single text to list if needed
- if isinstance(texts, str):
- texts = [texts]
-
- embeddings = []
- for text in texts:
- try:
- response = client.embeddings.create(model=model, input=[text], **kwargs)
- embeddings.append(response.data[0].embedding)
- except Exception as e:
- raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
-
- return np.array(embeddings)
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=60),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def openai_embedding(
- texts: list[str],
- model: str = "text-embedding-3-small",
- base_url: str = None,
- api_key: str = None,
-) -> np.ndarray:
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
-
- openai_async_client = (
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
- )
- response = await openai_async_client.embeddings.create(
- model=model, input=texts, encoding_format="float"
- )
- return np.array([dp.embedding for dp in response.data])
-
-
-async def fetch_data(url, headers, data):
- async with aiohttp.ClientSession() as session:
- async with session.post(url, headers=headers, json=data) as response:
- response_json = await response.json()
- data_list = response_json.get("data", [])
- return data_list
-
-
-async def jina_embedding(
- texts: list[str],
- dimensions: int = 1024,
- late_chunking: bool = False,
- base_url: str = None,
- api_key: str = None,
-) -> np.ndarray:
- if api_key:
- os.environ["JINA_API_KEY"] = api_key
- url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
- }
- data = {
- "model": "jina-embeddings-v3",
- "normalized": True,
- "embedding_type": "float",
- "dimensions": f"{dimensions}",
- "late_chunking": late_chunking,
- "input": texts,
- }
- data_list = await fetch_data(url, headers, data)
- return np.array([dp["embedding"] for dp in data_list])
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=60),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def nvidia_openai_embedding(
- texts: list[str],
- model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
- # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
- base_url: str = "https://integrate.api.nvidia.com/v1",
- api_key: str = None,
- input_type: str = "passage", # query for retrieval, passage for embedding
- trunc: str = "NONE", # NONE or START or END
- encode: str = "float", # float or base64
-) -> np.ndarray:
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
-
- openai_async_client = (
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
- )
- response = await openai_async_client.embeddings.create(
- model=model,
- input=texts,
- encoding_format=encode,
- extra_body={"input_type": input_type, "truncate": trunc},
- )
- return np.array([dp.embedding for dp in response.data])
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def azure_openai_embedding(
- texts: list[str],
- model: str = "text-embedding-3-small",
- base_url: str = None,
- api_key: str = None,
- api_version: str = None,
-) -> np.ndarray:
- if api_key:
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
- if base_url:
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- if api_version:
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
-
- openai_async_client = AsyncAzureOpenAI(
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
- )
-
- response = await openai_async_client.embeddings.create(
- model=model, input=texts, encoding_format="float"
- )
- return np.array([dp.embedding for dp in response.data])
-
-
-@retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=60),
- retry=retry_if_exception_type(
- (RateLimitError, APIConnectionError, APITimeoutError)
- ),
-)
-async def siliconcloud_embedding(
- texts: list[str],
- model: str = "netease-youdao/bce-embedding-base_v1",
- base_url: str = "https://api.siliconflow.cn/v1/embeddings",
- max_token_size: int = 512,
- api_key: str = None,
-) -> np.ndarray:
- if api_key and not api_key.startswith("Bearer "):
- api_key = "Bearer " + api_key
-
- headers = {"Authorization": api_key, "Content-Type": "application/json"}
-
- truncate_texts = [text[0:max_token_size] for text in texts]
-
- payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
-
- base64_strings = []
- async with aiohttp.ClientSession() as session:
- async with session.post(base_url, headers=headers, json=payload) as response:
- content = await response.json()
- if "code" in content:
- raise ValueError(content)
- base64_strings = [item["embedding"] for item in content["data"]]
-
- embeddings = []
- for string in base64_strings:
- decode_bytes = base64.b64decode(string)
- n = len(decode_bytes) // 4
- float_array = struct.unpack("<" + "f" * n, decode_bytes)
- embeddings.append(float_array)
- return np.array(embeddings)
-
-
-# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
-# @retry(
-# stop=stop_after_attempt(3),
-# wait=wait_exponential(multiplier=1, min=4, max=10),
-# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
-# )
-async def bedrock_embedding(
- texts: list[str],
- model: str = "amazon.titan-embed-text-v2:0",
- aws_access_key_id=None,
- aws_secret_access_key=None,
- aws_session_token=None,
-) -> np.ndarray:
- os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
- "AWS_ACCESS_KEY_ID", aws_access_key_id
- )
- os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
- "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
- )
- os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
- "AWS_SESSION_TOKEN", aws_session_token
- )
-
- session = aioboto3.Session()
- async with session.client("bedrock-runtime") as bedrock_async_client:
- if (model_provider := model.split(".")[0]) == "amazon":
- embed_texts = []
- for text in texts:
- if "v2" in model:
- body = json.dumps(
- {
- "inputText": text,
- # 'dimensions': embedding_dim,
- "embeddingTypes": ["float"],
- }
- )
- elif "v1" in model:
- body = json.dumps({"inputText": text})
- else:
- raise ValueError(f"Model {model} is not supported!")
-
- response = await bedrock_async_client.invoke_model(
- modelId=model,
- body=body,
- accept="application/json",
- contentType="application/json",
- )
-
- response_body = await response.get("body").json()
-
- embed_texts.append(response_body["embedding"])
- elif model_provider == "cohere":
- body = json.dumps(
- {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
- )
-
- response = await bedrock_async_client.invoke_model(
- model=model,
- body=body,
- accept="application/json",
- contentType="application/json",
- )
-
- response_body = json.loads(response.get("body").read())
-
- embed_texts = response_body["embeddings"]
- else:
- raise ValueError(f"Model provider '{model_provider}' is not supported!")
-
- return np.array(embed_texts)
-
-
-async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
- device = next(embed_model.parameters()).device
- input_ids = tokenizer(
- texts, return_tensors="pt", padding=True, truncation=True
- ).input_ids.to(device)
- with torch.no_grad():
- outputs = embed_model(input_ids)
- embeddings = outputs.last_hidden_state.mean(dim=1)
- if embeddings.dtype == torch.bfloat16:
- return embeddings.detach().to(torch.float32).cpu().numpy()
- else:
- return embeddings.detach().cpu().numpy()
-
-
-async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
- """
- Deprecated in favor of `embed`.
- """
- embed_text = []
- ollama_client = ollama.Client(**kwargs)
- for text in texts:
- data = ollama_client.embeddings(model=embed_model, prompt=text)
- embed_text.append(data["embedding"])
-
- return embed_text
-
-
-async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
- api_key = kwargs.pop("api_key", None)
- headers = (
- {"Content-Type": "application/json", "Authorization": api_key}
- if api_key
- else {"Content-Type": "application/json"}
- )
- kwargs["headers"] = headers
- ollama_client = ollama.Client(**kwargs)
- data = ollama_client.embed(model=embed_model, input=texts)
- return data["embeddings"]
-
-
-async def lollms_embed(
- texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
-) -> np.ndarray:
- """
- Generate embeddings for a list of texts using lollms server.
-
- Args:
- texts: List of strings to embed
- embed_model: Model name (not used directly as lollms uses configured vectorizer)
- base_url: URL of the lollms server
- **kwargs: Additional arguments passed to the request
-
- Returns:
- np.ndarray: Array of embeddings
- """
- api_key = kwargs.pop("api_key", None)
- headers = (
- {"Content-Type": "application/json", "Authorization": api_key}
- if api_key
- else {"Content-Type": "application/json"}
- )
- async with aiohttp.ClientSession(headers=headers) as session:
- embeddings = []
- for text in texts:
- request_data = {"text": text}
-
- async with session.post(
- f"{base_url}/lollms_embed",
- json=request_data,
- ) as response:
- result = await response.json()
- embeddings.append(result["vector"])
-
- return np.array(embeddings)
class Model(BaseModel):
@@ -1293,6 +87,8 @@ if __name__ == "__main__":
import asyncio
async def main():
+ from lightrag.llm.openai import gpt_4o_mini_complete
+
result = await gpt_4o_mini_complete("How are you?")
print(result)
diff --git a/lightrag/llm/__init__.py b/lightrag/llm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
new file mode 100644
index 00000000..83f89bcb
--- /dev/null
+++ b/lightrag/llm/azure_openai.py
@@ -0,0 +1,188 @@
+"""
+Azure OpenAI LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with Azure OpenAI's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - openai
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+
+import os
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("openai"):
+ pm.install("openai")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+from openai import (
+ AsyncAzureOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.utils import (
+ wrap_embedding_func_with_attrs,
+ locate_json_string_body_from_string,
+ safe_unicode_decode,
+)
+
+import numpy as np
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+        (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def azure_openai_complete_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url=None,
+ api_key=None,
+ api_version=None,
+ **kwargs,
+):
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+ if api_version:
+ os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
+ kwargs.pop("hashing_kv", None)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ if prompt is not None:
+ messages.append({"role": "user", "content": prompt})
+
+ if "response_format" in kwargs:
+ response = await openai_async_client.beta.chat.completions.parse(
+ model=model, messages=messages, **kwargs
+ )
+ else:
+ response = await openai_async_client.chat.completions.create(
+ model=model, messages=messages, **kwargs
+ )
+
+ if hasattr(response, "__aiter__"):
+
+ async def inner():
+ async for chunk in response:
+ if len(chunk.choices) == 0:
+ continue
+ content = chunk.choices[0].delta.content
+ if content is None:
+ continue
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ yield content
+
+ return inner()
+ else:
+ content = response.choices[0].message.content
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ return content
+
+
+async def azure_openai_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ result = await azure_openai_complete_if_cache(
+ os.getenv("LLM_MODEL", "gpt-4o-mini"),
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction: # TODO: use JSON API
+ return locate_json_string_body_from_string(result)
+ return result
+
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def azure_openai_embed(
+ texts: list[str],
+ model: str = "text-embedding-3-small",
+ base_url: str = None,
+ api_key: str = None,
+ api_version: str = None,
+) -> np.ndarray:
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+ if api_version:
+ os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
+
+ response = await openai_async_client.embeddings.create(
+ model=model, input=texts, encoding_format="float"
+ )
+ return np.array([dp.embedding for dp in response.data])
+
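For reviewers, a minimal usage sketch of the new `lightrag.llm.azure_openai` module. It assumes the standard Azure variables (`AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_VERSION`) are exported and that `LLM_MODEL` names an existing deployment; nothing below is prescribed by the PR itself.

```python
import asyncio

from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed


async def main():
    # Chat completion against the deployment named in LLM_MODEL (default "gpt-4o-mini")
    answer = await azure_openai_complete("What does LightRAG store in its graph?")
    print(answer)

    # Embeddings come back as a numpy array, one row per input text
    vectors = await azure_openai_embed(["hello", "world"])
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())
```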
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
new file mode 100644
index 00000000..c03ec42d
--- /dev/null
+++ b/lightrag/llm/bedrock.py
@@ -0,0 +1,229 @@
+"""
+Bedrock LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with Bedrock's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - aioboto3, tenacity
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+
+import sys
+import copy
+import os
+import json
+
+import pipmaster as pm # Pipmaster for dynamic library install
+if not pm.is_installed("aioboto3"):
+ pm.install("aioboto3")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+import aioboto3
+import numpy as np
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.exceptions import (
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from lightrag.utils import (
+ locate_json_string_body_from_string,
+)
+
+class BedrockError(Exception):
+ """Generic error for issues related to Amazon Bedrock"""
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ wait=wait_exponential(multiplier=1, max=60),
+ retry=retry_if_exception_type((BedrockError)),
+)
+async def bedrock_complete_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
+ aws_session_token=None,
+ **kwargs,
+) -> str:
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
+ )
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+ )
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+ "AWS_SESSION_TOKEN", aws_session_token
+ )
+ kwargs.pop("hashing_kv", None)
+ # Fix message history format
+ messages = []
+ for history_message in history_messages:
+ message = copy.copy(history_message)
+ message["content"] = [{"text": message["content"]}]
+ messages.append(message)
+
+ # Add user prompt
+ messages.append({"role": "user", "content": [{"text": prompt}]})
+
+ # Initialize Converse API arguments
+ args = {"modelId": model, "messages": messages}
+
+ # Define system prompt
+ if system_prompt:
+ args["system"] = [{"text": system_prompt}]
+
+ # Map and set up inference parameters
+ inference_params_map = {
+ "max_tokens": "maxTokens",
+ "top_p": "topP",
+ "stop_sequences": "stopSequences",
+ }
+ if inference_params := list(
+ set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+ ):
+ args["inferenceConfig"] = {}
+ for param in inference_params:
+ args["inferenceConfig"][inference_params_map.get(param, param)] = (
+ kwargs.pop(param)
+ )
+
+ # Call model via Converse API
+ session = aioboto3.Session()
+ async with session.client("bedrock-runtime") as bedrock_async_client:
+ try:
+ response = await bedrock_async_client.converse(**args, **kwargs)
+ except Exception as e:
+ raise BedrockError(e)
+
+ return response["output"]["message"]["content"][0]["text"]
+
+
+async def bedrock_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ result = await bedrock_complete_if_cache(
+ "anthropic.claude-3-haiku-20240307-v1:0",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction: # TODO: use JSON API
+ return locate_json_string_body_from_string(result)
+ return result
+
+
+# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+# @retry(
+# stop=stop_after_attempt(3),
+# wait=wait_exponential(multiplier=1, min=4, max=10),
+# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
+# )
+async def bedrock_embed(
+ texts: list[str],
+ model: str = "amazon.titan-embed-text-v2:0",
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
+ aws_session_token=None,
+) -> np.ndarray:
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
+ )
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+ )
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+ "AWS_SESSION_TOKEN", aws_session_token
+ )
+
+ session = aioboto3.Session()
+ async with session.client("bedrock-runtime") as bedrock_async_client:
+ if (model_provider := model.split(".")[0]) == "amazon":
+ embed_texts = []
+ for text in texts:
+ if "v2" in model:
+ body = json.dumps(
+ {
+ "inputText": text,
+ # 'dimensions': embedding_dim,
+ "embeddingTypes": ["float"],
+ }
+ )
+ elif "v1" in model:
+ body = json.dumps({"inputText": text})
+ else:
+ raise ValueError(f"Model {model} is not supported!")
+
+ response = await bedrock_async_client.invoke_model(
+ modelId=model,
+ body=body,
+ accept="application/json",
+ contentType="application/json",
+ )
+
+ response_body = await response.get("body").json()
+
+ embed_texts.append(response_body["embedding"])
+ elif model_provider == "cohere":
+ body = json.dumps(
+ {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
+ )
+
+ response = await bedrock_async_client.invoke_model(
+                modelId=model,
+ body=body,
+ accept="application/json",
+ contentType="application/json",
+ )
+
+ response_body = json.loads(response.get("body").read())
+
+ embed_texts = response_body["embeddings"]
+ else:
+ raise ValueError(f"Model provider '{model_provider}' is not supported!")
+
+ return np.array(embed_texts)
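A hedged sketch of exercising the Bedrock helpers, assuming AWS credentials are already available through the environment or a default profile and that the account has access to the default Claude 3 Haiku and Titan embedding models:

```python
import asyncio

from lightrag.llm.bedrock import bedrock_complete, bedrock_embed


async def main():
    # bedrock_complete defaults to anthropic.claude-3-haiku-20240307-v1:0
    reply = await bedrock_complete("Summarize retrieval-augmented generation in one sentence.")
    print(reply)

    # bedrock_embed defaults to amazon.titan-embed-text-v2:0
    vectors = await bedrock_embed(["graph retrieval", "vector retrieval"])
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())
```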
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
new file mode 100644
index 00000000..dad8f3a8
--- /dev/null
+++ b/lightrag/llm/hf.py
@@ -0,0 +1,187 @@
+"""
+Hugging face LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with Hugging face's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - transformers
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.hf import hf_model_complete, hf_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import copy
+import os
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("transformers"):
+ pm.install("transformers")
+if not pm.is_installed("torch"):
+ pm.install("torch")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from functools import lru_cache
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+from lightrag.exceptions import (
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from lightrag.utils import (
+ locate_json_string_body_from_string,
+)
+import torch
+import numpy as np
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+ hf_tokenizer = AutoTokenizer.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ if hf_tokenizer.pad_token is None:
+ hf_tokenizer.pad_token = hf_tokenizer.eos_token
+
+ return hf_model, hf_tokenizer
+
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def hf_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ **kwargs,
+) -> str:
+ model_name = model
+ hf_model, hf_tokenizer = initialize_hf_model(model_name)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ kwargs.pop("hashing_kv", None)
+ input_prompt = ""
+ try:
+ input_prompt = hf_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ try:
+ ori_message = copy.deepcopy(messages)
+ if messages[0]["role"] == "system":
+ messages[1]["content"] = (
+ ""
+ + messages[0]["content"]
+ + "\n"
+ + messages[1]["content"]
+ )
+ messages = messages[1:]
+ input_prompt = hf_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ len_message = len(ori_message)
+ for msgid in range(len_message):
+ input_prompt = (
+ input_prompt
+ + "<"
+ + ori_message[msgid]["role"]
+ + ">"
+ + ori_message[msgid]["content"]
+ + ""
+ + ori_message[msgid]["role"]
+ + ">\n"
+ )
+
+    # Tokenize on CPU, then move the tensors to whatever device the model
+    # actually lives on (avoids the previous hard-coded "cuda" requirement).
+    input_ids = hf_tokenizer(
+        input_prompt, return_tensors="pt", padding=True, truncation=True
+    )
+    inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
+    output = hf_model.generate(
+        **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+    )
+ response_text = hf_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+ )
+
+ return response_text
+
+
+
+async def hf_model_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ result = await hf_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction: # TODO: use JSON API
+ return locate_json_string_body_from_string(result)
+ return result
+
+
+async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+ device = next(embed_model.parameters()).device
+ input_ids = tokenizer(
+ texts, return_tensors="pt", padding=True, truncation=True
+ ).input_ids.to(device)
+ with torch.no_grad():
+ outputs = embed_model(input_ids)
+ embeddings = outputs.last_hidden_state.mean(dim=1)
+ if embeddings.dtype == torch.bfloat16:
+ return embeddings.detach().to(torch.float32).cpu().numpy()
+ else:
+ return embeddings.detach().cpu().numpy()
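A sketch of `hf_embed` with a small Hugging Face encoder; the model name is only an illustration, not something this module requires:

```python
import asyncio

from transformers import AutoModel, AutoTokenizer

from lightrag.llm.hf import hf_embed

MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)


async def main():
    # Mean-pooled last hidden states, returned as a numpy array
    vectors = await hf_embed(
        ["LightRAG indexes documents into a knowledge graph."], tokenizer, model
    )
    print(vectors.shape)  # (1, 384) for this encoder


if __name__ == "__main__":
    asyncio.run(main())
```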
diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py
new file mode 100644
index 00000000..07e680e1
--- /dev/null
+++ b/lightrag/llm/jina.py
@@ -0,0 +1,104 @@
+"""
+Jina Embedding Interface Module
+==========================
+
+This module provides interfaces for interacting with jina system,
+including embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added embedding generation
+
+Dependencies:
+ - tenacity
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.jina import jina_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import os
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("lmdeploy"):
+ pm.install("lmdeploy")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.utils import (
+ wrap_embedding_func_with_attrs,
+ locate_json_string_body_from_string,
+ safe_unicode_decode,
+ logger,
+)
+
+from lightrag.types import GPTKeywordExtractionFormat
+from functools import lru_cache
+
+import numpy as np
+from typing import Union
+import aiohttp
+
+
+async def fetch_data(url, headers, data):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(url, headers=headers, json=data) as response:
+ response_json = await response.json()
+ data_list = response_json.get("data", [])
+ return data_list
+
+
+async def jina_embed(
+ texts: list[str],
+ dimensions: int = 1024,
+ late_chunking: bool = False,
+ base_url: str = None,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key:
+ os.environ["JINA_API_KEY"] = api_key
+ url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
+ }
+ data = {
+ "model": "jina-embeddings-v3",
+ "normalized": True,
+ "embedding_type": "float",
+ "dimensions": f"{dimensions}",
+ "late_chunking": late_chunking,
+ "input": texts,
+ }
+ data_list = await fetch_data(url, headers, data)
+ return np.array([dp["embedding"] for dp in data_list])
+
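A minimal sketch of `jina_embed`; it assumes `JINA_API_KEY` is set in the environment (or an `api_key` argument is passed):

```python
import asyncio

from lightrag.llm.jina import jina_embed


async def main():
    vectors = await jina_embed(["hello world"], dimensions=1024)
    print(vectors.shape)  # (1, 1024)


if __name__ == "__main__":
    asyncio.run(main())
```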
diff --git a/lightrag/llm/lmdeploy.py b/lightrag/llm/lmdeploy.py
new file mode 100644
index 00000000..7accbfab
--- /dev/null
+++ b/lightrag/llm/lmdeploy.py
@@ -0,0 +1,190 @@
+"""
+LMDeploy LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with LMDeploy's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - tenacity
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("lmdeploy"):
+ pm.install("lmdeploy[all]")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+from lightrag.exceptions import (
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+
+from functools import lru_cache
+
+@lru_cache(maxsize=1)
+def initialize_lmdeploy_pipeline(
+ model,
+ tp=1,
+ chat_template=None,
+ log_level="WARNING",
+ model_format="hf",
+ quant_policy=0,
+):
+ from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
+ lmdeploy_pipe = pipeline(
+ model_path=model,
+ backend_config=TurbomindEngineConfig(
+ tp=tp, model_format=model_format, quant_policy=quant_policy
+ ),
+ chat_template_config=(
+ ChatTemplateConfig(model_name=chat_template) if chat_template else None
+ ),
+        log_level=log_level,
+ )
+ return lmdeploy_pipe
+
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def lmdeploy_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ chat_template=None,
+ model_format="hf",
+ quant_policy=0,
+ **kwargs,
+) -> str:
+ """
+ Args:
+ model (str): The path to the model.
+ It could be one of the following options:
+ - i) A local directory path of a turbomind model which is
+ converted by `lmdeploy convert` command or download
+ from ii) and iii).
+ - ii) The model_id of a lmdeploy-quantized model hosted
+ inside a model repo on huggingface.co, such as
+ "InternLM/internlm-chat-20b-4bit",
+ "lmdeploy/llama2-chat-70b-4bit", etc.
+ - iii) The model_id of a model hosted inside a model repo
+ on huggingface.co, such as "internlm/internlm-chat-7b",
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+ and so on.
+ chat_template (str): needed when model is a pytorch model on
+ huggingface.co, such as "internlm-chat-7b",
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
+ and when the model name of local path did not match the original model name in HF.
+ tp (int): tensor parallel
+ prompt (Union[str, List[str]]): input texts to be completed.
+ do_preprocess (bool): whether pre-process the messages. Default to
+ True, which means chat_template will be applied.
+ skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Default to be True.
+ do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
+ Default to be False, which means greedy decoding will be applied.
+ """
+ try:
+ import lmdeploy
+ from lmdeploy import version_info, GenerationConfig
+ except Exception:
+ raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
+ kwargs.pop("hashing_kv", None)
+ kwargs.pop("response_format", None)
+ max_new_tokens = kwargs.pop("max_tokens", 512)
+ tp = kwargs.pop("tp", 1)
+ skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+ do_preprocess = kwargs.pop("do_preprocess", True)
+ do_sample = kwargs.pop("do_sample", False)
+ gen_params = kwargs
+
+ version = version_info
+    if do_sample and version < (0, 6, 0):
+        raise RuntimeError(
+            "`do_sample` parameter is not supported by lmdeploy until "
+            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
+        )
+    # Honor the caller's choice (greedy decoding by default) instead of always
+    # forcing sampling on.
+    gen_params.update(do_sample=do_sample)
+
+ lmdeploy_pipe = initialize_lmdeploy_pipeline(
+ model=model,
+ tp=tp,
+ chat_template=chat_template,
+ model_format=model_format,
+ quant_policy=quant_policy,
+ log_level="WARNING",
+ )
+
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+
+ gen_config = GenerationConfig(
+ skip_special_tokens=skip_special_tokens,
+ max_new_tokens=max_new_tokens,
+ **gen_params,
+ )
+
+ response = ""
+ async for res in lmdeploy_pipe.generate(
+ messages,
+ gen_config=gen_config,
+ do_preprocess=do_preprocess,
+ stream_response=False,
+ session_id=1,
+ ):
+ response += res.response
+ return response
\ No newline at end of file
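Since this module only exposes the low-level `lmdeploy_model_if_cache` helper, a direct-call sketch follows; the model id is taken from the docstring above as an example and requires a GPU environment that lmdeploy supports:

```python
import asyncio

from lightrag.llm.lmdeploy import lmdeploy_model_if_cache


async def main():
    answer = await lmdeploy_model_if_cache(
        "internlm/internlm-chat-7b",   # example model from the docstring
        "What is a knowledge graph?",
        system_prompt="Answer in two sentences.",
        max_tokens=256,                # popped and mapped to max_new_tokens
    )
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```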
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
new file mode 100644
index 00000000..98c59734
--- /dev/null
+++ b/lightrag/llm/lollms.py
@@ -0,0 +1,222 @@
+"""
+LoLLMs (Lord of Large Language Models) Interface Module
+=====================================================
+
+This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems),
+a unified framework for AI model interaction and deployment.
+
+LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration
+with various AI models while maintaining high performance and user-friendly interfaces.
+
+Author: ParisNeo
+Created: 2024-01-24
+License: Apache 2.0
+
+Copyright (c) 2024 ParisNeo
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Version: 2.0.0
+
+Change Log:
+- 2.0.0 (2024-01-24):
+ * Added async support for model inference
+ * Implemented streaming capabilities
+ * Added embedding generation functionality
+ * Enhanced parameter handling
+ * Improved error handling and timeout management
+
+Dependencies:
+ - aiohttp
+ - numpy
+ - Python >= 3.10
+
+Features:
+ - Async text generation with streaming support
+ - Embedding generation
+ - Configurable model parameters
+ - System prompt and chat history support
+ - Timeout handling
+ - API key authentication
+
+Usage:
+    from lightrag.llm.lollms import lollms_model_complete, lollms_embed
+
+Project Repository: https://github.com/ParisNeo/lollms
+Documentation: https://github.com/ParisNeo/lollms/docs
+"""
+
+__version__ = "1.0.0"
+__author__ = "ParisNeo"
+__status__ = "Production"
+__project_url__ = "https://github.com/ParisNeo/lollms"
+__doc_url__ = "https://github.com/ParisNeo/lollms/docs"
+import sys
+if sys.version_info < (3, 9):
+ from typing import AsyncIterator
+else:
+ from collections.abc import AsyncIterator
+import pipmaster as pm # Pipmaster for dynamic library install
+if not pm.is_installed("aiohttp"):
+ pm.install("aiohttp")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+import aiohttp
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.exceptions import (
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+
+from typing import Union, List
+import numpy as np
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def lollms_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url="http://localhost:9600",
+ **kwargs,
+) -> Union[str, AsyncIterator[str]]:
+ """Client implementation for lollms generation."""
+
+ stream = True if kwargs.get("stream") else False
+ api_key = kwargs.pop("api_key", None)
+ headers = (
+ {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+ if api_key
+ else {"Content-Type": "application/json"}
+ )
+
+ # Extract lollms specific parameters
+ request_data = {
+ "prompt": prompt,
+ "model_name": model,
+ "personality": kwargs.get("personality", -1),
+ "n_predict": kwargs.get("n_predict", None),
+ "stream": stream,
+ "temperature": kwargs.get("temperature", 0.1),
+ "top_k": kwargs.get("top_k", 50),
+ "top_p": kwargs.get("top_p", 0.95),
+ "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
+ "repeat_last_n": kwargs.get("repeat_last_n", 40),
+ "seed": kwargs.get("seed", None),
+ "n_threads": kwargs.get("n_threads", 8),
+ }
+
+ # Prepare the full prompt including history
+ full_prompt = ""
+ if system_prompt:
+ full_prompt += f"{system_prompt}\n"
+ for msg in history_messages:
+ full_prompt += f"{msg['role']}: {msg['content']}\n"
+ full_prompt += prompt
+
+ request_data["prompt"] = full_prompt
+ timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
+
+ async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
+ if stream:
+
+ async def inner():
+ async with session.post(
+ f"{base_url}/lollms_generate", json=request_data
+ ) as response:
+ async for line in response.content:
+ yield line.decode().strip()
+
+ return inner()
+ else:
+ async with session.post(
+ f"{base_url}/lollms_generate", json=request_data
+ ) as response:
+ return await response.text()
+
+
+async def lollms_model_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+ """Complete function for lollms model generation."""
+
+ # Extract and remove keyword_extraction from kwargs if present
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+
+ # Get model name from config
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+
+ # If keyword extraction is needed, we might need to modify the prompt
+ # or add specific parameters for JSON output (if lollms supports it)
+ if keyword_extraction:
+ # Note: You might need to adjust this based on how lollms handles structured output
+ pass
+
+ return await lollms_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
+
+async def lollms_embed(
+ texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
+) -> np.ndarray:
+ """
+ Generate embeddings for a list of texts using lollms server.
+
+ Args:
+ texts: List of strings to embed
+ embed_model: Model name (not used directly as lollms uses configured vectorizer)
+ base_url: URL of the lollms server
+ **kwargs: Additional arguments passed to the request
+
+ Returns:
+ np.ndarray: Array of embeddings
+ """
+ api_key = kwargs.pop("api_key", None)
+ headers = (
+ {"Content-Type": "application/json", "Authorization": api_key}
+ if api_key
+ else {"Content-Type": "application/json"}
+ )
+ async with aiohttp.ClientSession(headers=headers) as session:
+ embeddings = []
+ for text in texts:
+ request_data = {"text": text}
+
+ async with session.post(
+ f"{base_url}/lollms_embed",
+ json=request_data,
+ ) as response:
+ result = await response.json()
+ embeddings.append(result["vector"])
+
+ return np.array(embeddings)
\ No newline at end of file
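A sketch against a locally running lollms server on the default port; the model name is illustrative, and `lollms_model_if_cache` is used directly because `lollms_model_complete` expects a `hashing_kv` entry injected by LightRAG:

```python
import asyncio

from lightrag.llm.lollms import lollms_embed, lollms_model_if_cache


async def main():
    reply = await lollms_model_if_cache(
        "mistral-nemo:latest",          # illustrative model name
        "Say hello in one short sentence.",
        base_url="http://localhost:9600",
    )
    print(reply)

    vectors = await lollms_embed(["hello"], base_url="http://localhost:9600")
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())
```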
diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py
new file mode 100644
index 00000000..3023af3f
--- /dev/null
+++ b/lightrag/llm/nvidia_openai.py
@@ -0,0 +1,112 @@
+"""
+NVIDIA OpenAI-Compatible Interface Module
+==========================
+
+This module provides interfaces for interacting with NVIDIA's OpenAI-compatible API,
+including embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - openai
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.nvidia_openai import nvidia_openai_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+
+
+import sys
+import os
+
+if sys.version_info < (3, 9):
+ from typing import AsyncIterator
+else:
+ from collections.abc import AsyncIterator
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("openai"):
+ pm.install("openai")
+
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.utils import (
+ wrap_embedding_func_with_attrs,
+ locate_json_string_body_from_string,
+ safe_unicode_decode,
+ logger,
+)
+
+from lightrag.types import GPTKeywordExtractionFormat
+
+import numpy as np
+
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def nvidia_openai_embed(
+ texts: list[str],
+ model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+ # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
+ base_url: str = "https://integrate.api.nvidia.com/v1",
+ api_key: str = None,
+ input_type: str = "passage", # query for retrieval, passage for embedding
+ trunc: str = "NONE", # NONE or START or END
+ encode: str = "float", # float or base64
+) -> np.ndarray:
+ if api_key:
+ os.environ["OPENAI_API_KEY"] = api_key
+
+ openai_async_client = (
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ )
+ response = await openai_async_client.embeddings.create(
+ model=model,
+ input=texts,
+ encoding_format=encode,
+ extra_body={"input_type": input_type, "truncate": trunc},
+ )
+ return np.array([dp.embedding for dp in response.data])
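A sketch of `nvidia_openai_embed`; the `NVIDIA_API_KEY` variable name is an assumption here, the function simply forwards whatever key it is given:

```python
import asyncio
import os

from lightrag.llm.nvidia_openai import nvidia_openai_embed


async def main():
    vectors = await nvidia_openai_embed(
        ["LightRAG stores entities and relations."],
        api_key=os.getenv("NVIDIA_API_KEY"),   # assumed env var name
        input_type="passage",
    )
    print(vectors.shape)  # (1, 2048) per the declared embedding_dim


if __name__ == "__main__":
    asyncio.run(main())
```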
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
new file mode 100644
index 00000000..353ddacc
--- /dev/null
+++ b/lightrag/llm/ollama.py
@@ -0,0 +1,155 @@
+"""
+Ollama LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with Ollama's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - ollama
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.ollama import ollama_model_complete, ollama_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import sys
+if sys.version_info < (3, 9):
+ from typing import AsyncIterator
+else:
+ from collections.abc import AsyncIterator
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("ollama"):
+ pm.install("ollama")
+if not pm.is_installed("tenacity"):
+ pm.install("tenacity")
+
+import ollama
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+from lightrag.exceptions import (
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+import numpy as np
+from typing import Union
+
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def ollama_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ **kwargs,
+) -> Union[str, AsyncIterator[str]]:
+ stream = True if kwargs.get("stream") else False
+ kwargs.pop("max_tokens", None)
+ # kwargs.pop("response_format", None) # allow json
+ host = kwargs.pop("host", None)
+ timeout = kwargs.pop("timeout", None)
+ kwargs.pop("hashing_kv", None)
+ api_key = kwargs.pop("api_key", None)
+ headers = (
+ {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+ if api_key
+ else {"Content-Type": "application/json"}
+ )
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+
+ response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+ if stream:
+ """cannot cache stream response"""
+
+ async def inner():
+ async for chunk in response:
+ yield chunk["message"]["content"]
+
+ return inner()
+ else:
+ return response["message"]["content"]
+
+async def ollama_model_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ if keyword_extraction:
+ kwargs["format"] = "json"
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ return await ollama_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+ """
+    Deprecated in favor of `ollama_embed`.
+ """
+ embed_text = []
+ ollama_client = ollama.Client(**kwargs)
+ for text in texts:
+ data = ollama_client.embeddings(model=embed_model, prompt=text)
+ embed_text.append(data["embedding"])
+
+ return embed_text
+
+
+async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+ api_key = kwargs.pop("api_key", None)
+ headers = (
+ {"Content-Type": "application/json", "Authorization": api_key}
+ if api_key
+ else {"Content-Type": "application/json"}
+ )
+ kwargs["headers"] = headers
+ ollama_client = ollama.Client(**kwargs)
+ data = ollama_client.embed(model=embed_model, input=texts)
+ return data["embeddings"]
\ No newline at end of file
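A sketch against a local Ollama server; the models are examples and must be pulled first (`ollama pull qwen2`, `ollama pull nomic-embed-text`). `ollama_model_if_cache` is called directly because `ollama_model_complete` reads the model name from LightRAG's `hashing_kv`:

```python
import asyncio

from lightrag.llm.ollama import ollama_embed, ollama_model_if_cache


async def main():
    answer = await ollama_model_if_cache(
        "qwen2",                         # example chat model
        "What is retrieval-augmented generation?",
        host="http://localhost:11434",
    )
    print(answer)

    vectors = await ollama_embed(
        ["hello"], embed_model="nomic-embed-text", host="http://localhost:11434"
    )
    print(len(vectors[0]))


if __name__ == "__main__":
    asyncio.run(main())
```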
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
new file mode 100644
index 00000000..d19fc7e1
--- /dev/null
+++ b/lightrag/llm/openai.py
@@ -0,0 +1,232 @@
+"""
+OpenAI LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with openai's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+ - openai
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.openai import openai_complete, openai_embed
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+
+
+import sys
+import os
+
+if sys.version_info < (3, 9):
+ from typing import AsyncIterator
+else:
+ from collections.abc import AsyncIterator
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("openai"):
+ pm.install("openai")
+
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+from lightrag.utils import (
+ wrap_embedding_func_with_attrs,
+ locate_json_string_body_from_string,
+ safe_unicode_decode,
+ logger,
+)
+from lightrag.types import GPTKeywordExtractionFormat
+
+import numpy as np
+from typing import Union
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def openai_complete_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url=None,
+ api_key=None,
+ **kwargs,
+) -> str:
+ if api_key:
+ os.environ["OPENAI_API_KEY"] = api_key
+
+ openai_async_client = (
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ )
+ kwargs.pop("hashing_kv", None)
+ kwargs.pop("keyword_extraction", None)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+
+    # Log the query that is sent to the LLM (debug level)
+ logger.debug("===== Query Input to LLM =====")
+ logger.debug(f"Query: {prompt}")
+ logger.debug(f"System prompt: {system_prompt}")
+ logger.debug("Full context:")
+ if "response_format" in kwargs:
+ response = await openai_async_client.beta.chat.completions.parse(
+ model=model, messages=messages, **kwargs
+ )
+ else:
+ response = await openai_async_client.chat.completions.create(
+ model=model, messages=messages, **kwargs
+ )
+
+ if hasattr(response, "__aiter__"):
+
+ async def inner():
+ async for chunk in response:
+ content = chunk.choices[0].delta.content
+ if content is None:
+ continue
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ yield content
+
+ return inner()
+ else:
+ content = response.choices[0].message.content
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ return content
+
+
+
+async def openai_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ if keyword_extraction:
+ kwargs["response_format"] = "json"
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ return await openai_complete_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
+async def gpt_4o_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ if keyword_extraction:
+ kwargs["response_format"] = GPTKeywordExtractionFormat
+ return await openai_complete_if_cache(
+ "gpt-4o",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
+async def gpt_4o_mini_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ if keyword_extraction:
+ kwargs["response_format"] = GPTKeywordExtractionFormat
+ return await openai_complete_if_cache(
+ "gpt-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
+async def nvidia_openai_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ result = await openai_complete_if_cache(
+ "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ base_url="https://integrate.api.nvidia.com/v1",
+ **kwargs,
+ )
+ if keyword_extraction: # TODO: use JSON API
+ return locate_json_string_body_from_string(result)
+ return result
+
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def openai_embed(
+ texts: list[str],
+ model: str = "text-embedding-3-small",
+ base_url: str = None,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key:
+ os.environ["OPENAI_API_KEY"] = api_key
+
+ openai_async_client = (
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ )
+ response = await openai_async_client.embeddings.create(
+ model=model, input=texts, encoding_format="float"
+ )
+ return np.array([dp.embedding for dp in response.data])
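Beyond the `gpt_4o_mini_complete` call shown in the demo hunk above, the low-level helpers also work against any OpenAI-compatible endpoint; a sketch, assuming `OPENAI_API_KEY` is set:

```python
import asyncio

from lightrag.llm.openai import openai_complete_if_cache, openai_embed


async def main():
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Explain the difference between local and global retrieval in one paragraph.",
        system_prompt="Answer briefly.",
    )
    print(answer)

    vectors = await openai_embed(["graph", "vector"])
    print(vectors.shape)  # (2, 1536) with the default text-embedding-3-small


if __name__ == "__main__":
    asyncio.run(main())
```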
diff --git a/lightrag/llm/siliconcloud.py b/lightrag/llm/siliconcloud.py
new file mode 100644
index 00000000..201fc93a
--- /dev/null
+++ b/lightrag/llm/siliconcloud.py
@@ -0,0 +1,121 @@
+"""
+SiliconCloud Embedding Interface Module
+==========================
+
+This module provides interfaces for interacting with SiliconCloud system,
+including embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added embedding generation
+
+Dependencies:
+ - tenacity
+ - numpy
+ - pipmaster
+ - Python >= 3.10
+
+Usage:
+    from lightrag.llm.siliconcloud import siliconcloud_embedding
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import sys
+import copy
+import os
+import json
+import base64
+import struct
+
+if sys.version_info < (3, 9):
+ from typing import AsyncIterator
+else:
+ from collections.abc import AsyncIterator
+import pipmaster as pm # Pipmaster for dynamic library install
+
+# install specific modules
+if not pm.is_installed("lmdeploy"):
+ pm.install("lmdeploy")
+
+from openai import (
+ AsyncOpenAI,
+ AsyncAzureOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ APITimeoutError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_exponential,
+ retry_if_exception_type,
+)
+
+from lightrag.utils import (
+ wrap_embedding_func_with_attrs,
+ locate_json_string_body_from_string,
+ safe_unicode_decode,
+ logger,
+)
+
+from lightrag.types import GPTKeywordExtractionFormat
+from functools import lru_cache
+
+import numpy as np
+from typing import Union
+import aiohttp
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def siliconcloud_embedding(
+ texts: list[str],
+ model: str = "netease-youdao/bce-embedding-base_v1",
+ base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+ max_token_size: int = 512,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key and not api_key.startswith("Bearer "):
+ api_key = "Bearer " + api_key
+
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
+
+ truncate_texts = [text[0:max_token_size] for text in texts]
+
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
+
+ base64_strings = []
+ async with aiohttp.ClientSession() as session:
+ async with session.post(base_url, headers=headers, json=payload) as response:
+ content = await response.json()
+ if "code" in content:
+ raise ValueError(content)
+ base64_strings = [item["embedding"] for item in content["data"]]
+
+ embeddings = []
+ for string in base64_strings:
+ decode_bytes = base64.b64decode(string)
+ n = len(decode_bytes) // 4
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
+ embeddings.append(float_array)
+ return np.array(embeddings)
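A sketch of `siliconcloud_embedding`; the `SILICONFLOW_API_KEY` variable name is an assumption, and the function prefixes the key with `Bearer ` itself:

```python
import asyncio
import os

from lightrag.llm.siliconcloud import siliconcloud_embedding


async def main():
    vectors = await siliconcloud_embedding(
        ["hello", "world"],
        api_key=os.getenv("SILICONFLOW_API_KEY"),  # assumed env var name
    )
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())
```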
diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py
new file mode 100644
index 00000000..08d98108
--- /dev/null
+++ b/lightrag/llm/zhipu.py
@@ -0,0 +1,250 @@
+"""
+Zhipu LLM Interface Module
+==========================
+
+This module provides interfaces for interacting with ZhipuAI's language models,
+including text generation and embedding capabilities.
+
+Author: Lightrag team
+Created: 2024-01-24
+License: MIT License
+
+Copyright (c) 2024 Lightrag
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+Version: 1.0.0
+
+Change Log:
+- 1.0.0 (2024-01-24): Initial release
+ * Added async chat completion support
+ * Added embedding generation
+ * Added stream response capability
+
+Dependencies:
+    - zhipuai (installed on demand via pipmaster)
+    - tenacity
+    - numpy
+    - pipmaster
+    - Python >= 3.10
+
+Usage:
+    from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
+    (see the commented usage sketch at the end of this file)
+"""
+
+__version__ = "1.0.0"
+__author__ = "lightrag Team"
+__status__ = "Production"
+
+import re
+import json
+
+import pipmaster as pm  # Pipmaster for dynamic library install
+
+# install specific modules (openai is only needed for the exception types
+# referenced in the retry policy below)
+if not pm.is_installed("zhipuai"):
+    pm.install("zhipuai")
+if not pm.is_installed("openai"):
+    pm.install("openai")
+
+import numpy as np
+from typing import Union, List, Optional, Dict
+
+from openai import (
+    APIConnectionError,
+    RateLimitError,
+    APITimeoutError,
+)
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+    logger,
+)
+from lightrag.types import GPTKeywordExtractionFormat
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def zhipu_complete_if_cache(
+ prompt: Union[str, List[Dict[str, str]]],
+    model: str = "glm-4-flashx",  # The best cost/performance balance in the glm-4 series
+ api_key: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ history_messages: List[Dict[str, str]] = [],
+ **kwargs,
+) -> str:
+ # dynamically load ZhipuAI
+ try:
+ from zhipuai import ZhipuAI
+ except ImportError:
+        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+
+ if api_key:
+ client = ZhipuAI(api_key=api_key)
+ else:
+ # please set ZHIPUAI_API_KEY in your environment
+ # os.environ["ZHIPUAI_API_KEY"]
+ client = ZhipuAI()
+
+ messages = []
+
+ if not system_prompt:
+ system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
+
+    # system_prompt is guaranteed to be set above, so it is always sent first
+    messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+
+ # Add debug logging
+ logger.debug("===== Query Input to LLM =====")
+ logger.debug(f"Query: {prompt}")
+ logger.debug(f"System prompt: {system_prompt}")
+
+ # Remove unsupported kwargs
+ kwargs = {
+ k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
+ }
+
+ response = client.chat.completions.create(model=model, messages=messages, **kwargs)
+
+ return response.choices[0].message.content
+
+
+async def zhipu_complete(
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+):
+    # keyword_extraction is a named parameter, so it never appears in kwargs and is
+    # therefore not forwarded to zhipu_complete_if_cache; do not overwrite it here
+
+ if keyword_extraction:
+ # Add a system prompt to guide the model to return JSON format
+ extraction_prompt = """You are a helpful assistant that extracts keywords from text.
+ Please analyze the content and extract two types of keywords:
+ 1. High-level keywords: Important concepts and main themes
+ 2. Low-level keywords: Specific details and supporting elements
+
+ Return your response in this exact JSON format:
+ {
+ "high_level_keywords": ["keyword1", "keyword2"],
+ "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
+ }
+
+ Only return the JSON, no other text."""
+
+ # Combine with existing system prompt if any
+ if system_prompt:
+ system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
+ else:
+ system_prompt = extraction_prompt
+
+ try:
+ response = await zhipu_complete_if_cache(
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+ # Try to parse as JSON
+ try:
+ data = json.loads(response)
+ return GPTKeywordExtractionFormat(
+ high_level_keywords=data.get("high_level_keywords", []),
+ low_level_keywords=data.get("low_level_keywords", []),
+ )
+ except json.JSONDecodeError:
+ # If direct JSON parsing fails, try to extract JSON from text
+ match = re.search(r"\{[\s\S]*\}", response)
+ if match:
+ try:
+ data = json.loads(match.group())
+ return GPTKeywordExtractionFormat(
+ high_level_keywords=data.get("high_level_keywords", []),
+ low_level_keywords=data.get("low_level_keywords", []),
+ )
+ except json.JSONDecodeError:
+ pass
+
+ # If all parsing fails, log warning and return empty format
+ logger.warning(
+ f"Failed to parse keyword extraction response: {response}"
+ )
+ return GPTKeywordExtractionFormat(
+ high_level_keywords=[], low_level_keywords=[]
+ )
+ except Exception as e:
+ logger.error(f"Error during keyword extraction: {str(e)}")
+ return GPTKeywordExtractionFormat(
+ high_level_keywords=[], low_level_keywords=[]
+ )
+ else:
+ # For non-keyword-extraction, just return the raw response string
+ return await zhipu_complete_if_cache(
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
+)
+async def zhipu_embedding(
+ texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
+) -> np.ndarray:
+ # dynamically load ZhipuAI
+ try:
+ from zhipuai import ZhipuAI
+ except ImportError:
+        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+ if api_key:
+ client = ZhipuAI(api_key=api_key)
+ else:
+ # please set ZHIPUAI_API_KEY in your environment
+ # os.environ["ZHIPUAI_API_KEY"]
+ client = ZhipuAI()
+
+ # Convert single text to list if needed
+ if isinstance(texts, str):
+ texts = [texts]
+
+    # Note: the ZhipuAI SDK call below is synchronous and issued one text at a
+    # time, so it blocks the event loop for the duration of each request
+    embeddings = []
+    for text in texts:
+        try:
+            response = client.embeddings.create(model=model, input=[text], **kwargs)
+            embeddings.append(response.data[0].embedding)
+        except Exception as e:
+            raise Exception(f"Error calling Zhipu Embedding API: {str(e)}") from e
+
+ return np.array(embeddings)
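+
+
+# A minimal usage sketch (comments only, nothing is executed on import): plugging
+# zhipu_complete and zhipu_embedding into LightRAG. The working directory is
+# illustrative; embedding_dim and max_token_size mirror the attributes declared
+# on zhipu_embedding above, and the chat model defaults to glm-4-flashx inside
+# zhipu_complete_if_cache.
+#
+#     from lightrag import LightRAG
+#     from lightrag.utils import EmbeddingFunc
+#     from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
+#
+#     rag = LightRAG(
+#         working_dir="./zhipu_working_dir",
+#         llm_model_func=zhipu_complete,
+#         embedding_func=EmbeddingFunc(
+#             embedding_dim=1024,
+#             max_token_size=8192,
+#             func=lambda texts: zhipu_embedding(texts),
+#         ),
+#     )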
\ No newline at end of file
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 797e6487..f28d52d0 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -6,6 +6,8 @@ from dataclasses import dataclass
from typing import Any, Union, cast, Dict
import networkx as nx
import numpy as np
+import pipmaster as pm
+
from nano_vectordb import NanoVectorDB
import time
diff --git a/lightrag/types.py b/lightrag/types.py
new file mode 100644
index 00000000..bc0c1186
--- /dev/null
+++ b/lightrag/types.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+from typing import List
+
+class GPTKeywordExtractionFormat(BaseModel):
+ high_level_keywords: List[str]
+ low_level_keywords: List[str]
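+
+
+# Example (a minimal sketch; the keyword values are illustrative): LLM backends
+# such as lightrag.llm.zhipu parse the model's JSON reply and build this object:
+#
+#     data = {"high_level_keywords": ["graph-based RAG"],
+#             "low_level_keywords": ["entity", "relation", "community"]}
+#     keywords = GPTKeywordExtractionFormat(**data)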
diff --git a/lightrag/utils.py b/lightrag/utils.py
index e902df85..898e66b4 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -535,7 +535,7 @@ class CacheData:
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
-
+    cache_type: str = "query"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
diff --git a/reproduce/Step_1_openai_compatible.py b/reproduce/Step_1_openai_compatible.py
index 8e67cfb8..09ca78c5 100644
--- a/reproduce/Step_1_openai_compatible.py
+++ b/reproduce/Step_1_openai_compatible.py
@@ -5,7 +5,7 @@ import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
## For Upstage API
@@ -25,7 +25,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py
index 5e2ef778..05e7d685 100644
--- a/reproduce/Step_3_openai_compatible.py
+++ b/reproduce/Step_3_openai_compatible.py
@@ -4,7 +4,7 @@ import json
import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
- return await openai_embedding(
+ return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
diff --git a/requirements.txt b/requirements.txt
index 1bfe8101..84f0f63c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,22 +1,20 @@
accelerate
-aioboto3
aiofiles
aiohttp
-aioredis
+redis
asyncpg
configparser
# database packages
graspologic
gremlinpython
-hnswlib
nano-vectordb
neo4j
networkx
+# TODO: Remove database-specific packages and move their installation to the
+# corresponding storage modules (use pipmaster to install them on demand)
numpy
-ollama
-openai
oracledb
pipmaster
psycopg-pool
@@ -33,14 +31,12 @@ python-dotenv
python-pptx
pyvis
setuptools
-# lmdeploy[all]
sqlalchemy
tenacity
# LLM packages
tiktoken
-torch
tqdm
-transformers
xxhash
+