Separated LLMs from the main llm.py file and fixed some deprecation bugs

This commit is contained in:
Saifeddine ALOUI
2025-01-25 00:11:00 +01:00
parent 7e1638525c
commit 34018cb1e0
55 changed files with 2144 additions and 1301 deletions
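For anyone updating existing scripts against this commit, the change reduces to two mechanical edits, sketched below with import paths taken from the diffs that follow: provider-specific functions move from the flat `lightrag.llm` module into per-provider submodules, and the `*_embedding` helpers are renamed to `*_embed`.

```python
# Before this commit: everything was imported from the flat lightrag.llm module.
# from lightrag.llm import gpt_4o_mini_complete, openai_embedding
# from lightrag.llm import ollama_model_complete, ollama_embedding

# After this commit: one submodule per provider, and *_embedding renamed to *_embed.
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
```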

2
.gitignore vendored
View File

@@ -22,3 +22,5 @@ venv/
examples/input/
examples/output/
.DS_Store
#Remove config.ini from repo
*.ini

View File

@@ -81,7 +81,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
```python
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -177,7 +177,7 @@ async def llm_model_func(
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
@@ -233,7 +233,7 @@ If you want to use Ollama models, you need to pull model you plan to use and emb
Then you only need to set LightRAG as follows:
```python
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# Initialize LightRAG with Ollama model
@@ -245,7 +245,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
func=lambda texts: ollama_embed(
texts,
embed_model="nomic-embed-text"
)
@@ -690,7 +690,7 @@ if __name__ == "__main__":
| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embed` |
| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |

View File

@@ -1,13 +0,0 @@
[redis]
uri = redis://localhost:6379
[neo4j]
uri = #
username = neo4j
password = 12345678
[milvus]
uri = #
user = root
password = Milvus
db_name = lightrag

View File

@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG
from lightrag.llm import gpt_4o_mini_complete
from lightrag.llm.openai import gpt_4o_mini_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio

View File

@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_embedding, ollama_model_complete
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
from typing import Optional
import asyncio
@@ -38,7 +38,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),

View File

@@ -9,7 +9,7 @@ from typing import Optional
import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
import nest_asyncio

View File

@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
@@ -48,7 +48,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model=EMBEDDING_MODEL,
)

View File

@@ -13,7 +13,7 @@ from pathlib import Path
import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -64,7 +64,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model=EMBEDDING_MODEL,
api_key=APIKEY,

View File

@@ -6,7 +6,7 @@ import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import bedrock_complete, bedrock_embedding
from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
from lightrag.utils import EmbeddingFunc
logging.getLogger("aiobotocore").setLevel(logging.WARNING)
@@ -20,7 +20,7 @@ rag = LightRAG(
llm_model_func=bedrock_complete,
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
embedding_func=EmbeddingFunc(
embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
embedding_dim=1024, max_token_size=8192, func=bedrock_embed
),
)

View File

@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model_complete, hf_embedding
from lightrag.llm.hf import hf_model_complete, hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
@@ -17,7 +17,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"

View File

@@ -1,13 +1,14 @@
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.llm import jina_embedding, openai_complete_if_cache
from lightrag.llm.jina import jina_embed
from lightrag.llm.openai import openai_complete_if_cache
import os
import asyncio
async def embedding_func(texts: list[str]) -> np.ndarray:
return await jina_embedding(texts, api_key="YourJinaAPIKey")
return await jina_embed(texts, api_key="YourJinaAPIKey")
WORKING_DIR = "./dickens"

View File

@@ -1,7 +1,8 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
from lightrag.llm.hf import hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
@@ -42,7 +43,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"

View File

@@ -3,7 +3,7 @@ import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
openai_complete_if_cache,
nvidia_openai_embedding,
nvidia_openai_embed,
)
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -47,7 +47,7 @@ nvidia_embed_model = "nvidia/nv-embedqa-e5-v5"
async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
return await nvidia_openai_embedding(
return await nvidia_openai_embed(
texts,
model=nvidia_embed_model, # maximum 512 token
# model="nvidia/llama-3.2-nv-embedqa-1b-v1",
@@ -60,7 +60,7 @@ async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
async def query_embedding_func(texts: list[str]) -> np.ndarray:
return await nvidia_openai_embedding(
return await nvidia_openai_embed(
texts,
model=nvidia_embed_model, # maximum 512 token
# model="nvidia/llama-3.2-nv-embedqa-1b-v1",

View File

@@ -4,7 +4,7 @@ import logging
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_embedding, ollama_model_complete
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens_age"
@@ -32,7 +32,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),

View File

@@ -3,7 +3,7 @@ import os
import inspect
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
@@ -23,7 +23,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),

View File

@@ -10,7 +10,7 @@ import os
# logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_embedding, ollama_model_complete
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens_gremlin"
@@ -41,7 +41,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),

View File

@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# WorkingDir

View File

@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),

View File

@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),

View File

@@ -1,7 +1,7 @@
import os
import inspect
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embedding
from lightrag.llm import openai_complete, openai_embed
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag import QueryParam
@@ -24,7 +24,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=lambda texts: openai_embedding(
func=lambda texts: openai_embed(
texts=texts,
model="text-embedding-bge-m3",
base_url="http://127.0.0.1:1234/v1",

View File

@@ -1,7 +1,7 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete
from lightrag.llm.openai import gpt_4o_mini_complete
WORKING_DIR = "./dickens"

View File

@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_embed, openai_complete_if_cache
from lightrag.llm.ollama import ollama_embed, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc
# WorkingDir

View File

@@ -3,7 +3,7 @@ import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
@@ -42,7 +42,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model=EMBEDMODEL,
api_key=APIKEY,

View File

@@ -1,7 +1,8 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

View File

@@ -3,7 +3,7 @@ import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import zhipu_complete, zhipu_embedding
from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"

View File

@@ -6,7 +6,7 @@ from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB
from lightrag.llm import ollama_embedding, zhipu_complete
from lightrag.llm.zhipu import ollama_embedding, zhipu_complete
from lightrag.utils import EmbeddingFunc
load_dotenv()

View File

@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete
from lightrag.llm.openai import gpt_4o_mini_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio

View File

@@ -1,7 +1,7 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, openai_embedding
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -35,7 +35,7 @@ EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model=EMBEDDING_MODEL,
)

View File

@@ -1,6 +1,6 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete
from lightrag.llm.openai import gpt_4o_mini_complete
#########

View File

@@ -16,7 +16,7 @@
"import logging\n",
"import numpy as np\n",
"from lightrag import LightRAG, QueryParam\n",
"from lightrag.llm import openai_complete_if_cache, openai_embedding\n",
"from lightrag.llm.openai import openai_complete_if_cache, openai_embed\n",
"from lightrag.utils import EmbeddingFunc\n",
"import nest_asyncio"
]
@@ -74,7 +74,7 @@
"\n",
"\n",
"async def embedding_func(texts: list[str]) -> np.ndarray:\n",
" return await openai_embedding(\n",
" return await openai_embed(\n",
" texts,\n",
" model=\"ep-20241231173413-pgjmk\",\n",
" api_key=API,\n",
@@ -138,7 +138,7 @@
"\n",
"\n",
"async def embedding_func(texts: list[str]) -> np.ndarray:\n",
" return await openai_embedding(\n",
" return await openai_embed(\n",
" texts,\n",
" model=\"ep-20241231173413-pgjmk\",\n",
" api_key=API,\n",

View File

@@ -1,7 +1,7 @@
import os
import time
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
# Working directory and the directory path for text files
@@ -20,7 +20,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
),
)

View File

@@ -8,10 +8,6 @@ import time
import re
from typing import List, Dict, Any, Optional, Union
from lightrag import LightRAG, QueryParam
from lightrag.llm import lollms_model_complete, lollms_embed
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
@@ -720,6 +716,20 @@ def create_app(args):
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
if args.llm_binding_host == "lollms" or args.embedding_binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
if args.llm_binding_host == "ollama" or args.embedding_binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
if args.llm_binding_host == "openai" or args.embedding_binding == "openai":
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
if (
args.llm_binding_host == "azure_openai"
or args.embedding_binding == "azure_openai"
):
from lightrag.llm.azure_openai import (
azure_openai_complete_if_cache,
azure_openai_embed,
)
async def openai_alike_model_complete(
prompt,
@@ -773,13 +783,13 @@ def create_app(args):
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "ollama"
else azure_openai_embedding(
else azure_openai_embed(
texts,
model=args.embedding_model, # no host is used for openai,
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "azure_openai"
else openai_embedding(
else openai_embed(
texts,
model=args.embedding_model, # no host is used for openai,
api_key=args.embedding_binding_api_key,

View File

@@ -1,4 +1,3 @@
aioboto3
ascii_colors
fastapi
nano_vectordb

55
lightrag/exceptions.py Normal file
View File

@@ -0,0 +1,55 @@
import httpx
from typing import Literal
class APIStatusError(Exception):
"""Raised when an API response has a status code of 4xx or 5xx."""
response: httpx.Response
status_code: int
request_id: str | None
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
super().__init__(message, response.request, body=body)
self.response = response
self.status_code = response.status_code
self.request_id = response.headers.get("x-request-id")
class APIConnectionError(Exception):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, body=None)
class BadRequestError(APIStatusError):
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
class AuthenticationError(APIStatusError):
status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
class PermissionDeniedError(APIStatusError):
status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
class NotFoundError(APIStatusError):
status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
class ConflictError(APIStatusError):
status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
class UnprocessableEntityError(APIStatusError):
status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
class RateLimitError(APIStatusError):
status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
class BadRequestError(APIStatusError):
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
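These exception classes mirror the OpenAI client's status-code hierarchy so the provider modules below can share a common retry policy. A minimal sketch of how they combine with tenacity, using the same decorator parameters that appear in the provider modules in this commit:

```python
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.exceptions import APIConnectionError, APITimeoutError, RateLimitError


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def call_provider(prompt: str) -> str:
    # Provider-specific request goes here; raising one of the exceptions
    # above triggers a retry with exponential backoff.
    ...
```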

View File

@@ -1,7 +1,8 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import aioredis
# aioredis is a deprecated library, replaced with redis.asyncio
from redis.asyncio import Redis
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
import json
@@ -11,7 +12,7 @@ import json
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
self._redis = aioredis.from_url(redis_url, decode_responses=True)
self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")
async def all_keys(self) -> list[str]:
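A minimal sketch of the replacement API, assuming the `redis` package (which now ships the asyncio client that supersedes aioredis) is installed and a local server is listening on the default `REDIS_URI` shown above:

```python
import asyncio

from redis.asyncio import Redis  # replaces the deprecated aioredis package


async def main():
    # Redis.from_url mirrors the aioredis.from_url call it replaces.
    r = Redis.from_url("redis://localhost:6379", decode_responses=True)
    await r.set("lightrag:ping", "pong")
    print(await r.get("lightrag:ping"))
    await r.close()


asyncio.run(main())
```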

View File

@@ -6,10 +6,6 @@ from datetime import datetime
from functools import partial
from typing import Type, cast, Dict
from .llm import (
gpt_4o_mini_complete,
openai_embedding,
)
from .operate import (
chunking_by_token_size,
extract_entities,
@@ -154,12 +150,12 @@ class LightRAG:
)
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_func: EmbeddingFunc = None # This must be set (we want to separate the LLM layer from the core, so no more default initialization)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_func: callable = None # This must be set (we want to separate the LLM layer from the core, so no more default initialization)
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
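With the defaults removed, both fields must now be wired explicitly when constructing `LightRAG`; a minimal sketch based on the OpenAI example from the README above (the embedding dimension is an assumption for the default OpenAI embedding model):

```python
from lightrag import LightRAG
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./dickens",
    # No more implicit gpt_4o_mini_complete / openai_embedding defaults:
    llm_model_func=gpt_4o_mini_complete,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,   # assumption: dimension of the default OpenAI embedding model
        max_token_size=8192,
        func=openai_embed,
    ),
)
```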

File diff suppressed because it is too large

0
lightrag/llm/__init__.py Normal file
View File

View File

@@ -0,0 +1,188 @@
"""
Azure OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with Azure OpenAI's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.azure_openai import azure_openai_model_complete, azure_openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from openai import (
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
)
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APIConnectionError)
),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_version: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
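A usage sketch for the Azure module, assuming the endpoint, key, API version, and deployment name are supplied through the environment variables the functions above read (all values below are placeholders):

```python
import asyncio
import os

from lightrag.llm.azure_openai import azure_openai_complete


async def main():
    os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<resource>.openai.azure.com")
    os.environ.setdefault("AZURE_OPENAI_API_KEY", "<key>")
    os.environ.setdefault("AZURE_OPENAI_API_VERSION", "2024-02-01")  # placeholder version
    os.environ.setdefault("LLM_MODEL", "gpt-4o-mini")  # deployment name, default from the code above

    print(await azure_openai_complete("Say hello in one short sentence."))


asyncio.run(main())
```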

229
lightrag/llm/bedrock.py Normal file
View File

@@ -0,0 +1,229 @@
"""
Bedrock LLM Interface Module
==========================
This module provides interfaces for interacting with Bedrock's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- aioboto3, tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.bedrock import bedrock_complete, bedrock_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import copy
import os
import json
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aioboto3"):
pm.install("aioboto3")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import aioboto3
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from lightrag.utils import (
locate_json_string_body_from_string,
)
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
kwargs.pop("hashing_kv", None)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
return response["output"]["message"]["content"][0]["text"]
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embed(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
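A usage sketch for the embedding helper, assuming AWS credentials are already present in the environment (the function copies them from `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` if they are not passed explicitly):

```python
import asyncio

from lightrag.llm.bedrock import bedrock_embed


async def main():
    # Assumes AWS credentials and region are already configured in the environment.
    vectors = await bedrock_embed(
        ["hello world", "lightrag"],
        model="amazon.titan-embed-text-v2:0",  # default model from the function above
    )
    print(vectors.shape)


asyncio.run(main())
```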

187
lightrag/llm/hf.py Normal file
View File

@@ -0,0 +1,187 @@
"""
Hugging Face LLM Interface Module
==========================
This module provides interfaces for interacting with Hugging Face's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- transformers
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.hf import hf_model_complete, hf_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import copy
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("transformers"):
pm.install("transformers")
if not pm.is_installed("torch"):
pm.install("torch")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import lru_cache
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from lightrag.utils import (
locate_json_string_body_from_string,
)
import torch
import numpy as np
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def hf_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> str:
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
kwargs.pop("hashing_kv", None)
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
return response_text
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
device = next(embed_model.parameters()).device
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids.to(device)
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:
return embeddings.detach().cpu().numpy()
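A usage sketch that mirrors the Hugging Face demo earlier in this commit; the sentence-transformers model is the same one that example uses:

```python
import asyncio

from transformers import AutoModel, AutoTokenizer

from lightrag.llm.hf import hf_embed


async def main():
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    vectors = await hf_embed(["hello world"], tokenizer=tokenizer, embed_model=embed_model)
    print(vectors.shape)  # (1, 384) for this model


asyncio.run(main())
```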

104
lightrag/llm/jina.py Normal file
View File

@@ -0,0 +1,104 @@
"""
Jina Embedding Interface Module
==========================
This module provides interfaces for interacting with the Jina embedding service,
including embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added embedding generation
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.jina import jina_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from functools import lru_cache
import numpy as np
from typing import Union
import aiohttp
async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
response_json = await response.json()
data_list = response_json.get("data", [])
return data_list
async def jina_embed(
texts: list[str],
dimensions: int = 1024,
late_chunking: bool = False,
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["JINA_API_KEY"] = api_key
url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
data = {
"model": "jina-embeddings-v3",
"normalized": True,
"embedding_type": "float",
"dimensions": f"{dimensions}",
"late_chunking": late_chunking,
"input": texts,
}
data_list = await fetch_data(url, headers, data)
return np.array([dp["embedding"] for dp in data_list])

190
lightrag/llm/lmdeploy.py Normal file
View File

@@ -0,0 +1,190 @@
"""
LMDeploy LLM Interface Module
==========================
This module provides interfaces for interacting with LMDeploy's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.lmdeploy import lmdeploy_model_complete, lmdeploy_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy[all]")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from functools import lru_cache
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=(
ChatTemplateConfig(model_name=chat_template) if chat_template else None
),
log_level="WARNING",
)
return lmdeploy_pipe
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
kwargs.pop("hashing_kv", None)
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
return response

222
lightrag/llm/lollms.py Normal file
View File

@@ -0,0 +1,222 @@
"""
LoLLMs (Lord of Large Language Models) Interface Module
=====================================================
This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems),
a unified framework for AI model interaction and deployment.
LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration
with various AI models while maintaining high performance and user-friendly interfaces.
Author: ParisNeo
Created: 2024-01-24
License: Apache 2.0
Copyright (c) 2024 ParisNeo
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Version: 2.0.0
Change Log:
- 2.0.0 (2024-01-24):
* Added async support for model inference
* Implemented streaming capabilities
* Added embedding generation functionality
* Enhanced parameter handling
* Improved error handling and timeout management
Dependencies:
- aiohttp
- numpy
- Python >= 3.10
Features:
- Async text generation with streaming support
- Embedding generation
- Configurable model parameters
- System prompt and chat history support
- Timeout handling
- API key authentication
Usage:
from llm_interfaces.lollms import lollms_model_complete, lollms_embed
Project Repository: https://github.com/ParisNeo/lollms
Documentation: https://github.com/ParisNeo/lollms/docs
"""
__version__ = "1.0.0"
__author__ = "ParisNeo"
__status__ = "Production"
__project_url__ = "https://github.com/ParisNeo/lollms"
__doc_url__ = "https://github.com/ParisNeo/lollms/docs"
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aiohttp"):
pm.install("aiohttp")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from typing import Union, List
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lollms_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
if api_key
else {"Content-Type": "application/json"}
)
# Extract lollms specific parameters
request_data = {
"prompt": prompt,
"model_name": model,
"personality": kwargs.get("personality", -1),
"n_predict": kwargs.get("n_predict", None),
"stream": stream,
"temperature": kwargs.get("temperature", 0.1),
"top_k": kwargs.get("top_k", 50),
"top_p": kwargs.get("top_p", 0.95),
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
"repeat_last_n": kwargs.get("repeat_last_n", 40),
"seed": kwargs.get("seed", None),
"n_threads": kwargs.get("n_threads", 8),
}
# Prepare the full prompt including history
full_prompt = ""
if system_prompt:
full_prompt += f"{system_prompt}\n"
for msg in history_messages:
full_prompt += f"{msg['role']}: {msg['content']}\n"
full_prompt += prompt
request_data["prompt"] = full_prompt
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
if stream:
async def inner():
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
async for line in response.content:
yield line.decode().strip()
return inner()
else:
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
return await response.text()
async def lollms_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
# Extract and remove keyword_extraction from kwargs if present
keyword_extraction = kwargs.pop("keyword_extraction", None)
# Get model name from config
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
# If keyword extraction is needed, we might need to modify the prompt
# or add specific parameters for JSON output (if lollms supports it)
if keyword_extraction:
# Note: You might need to adjust this based on how lollms handles structured output
pass
return await lollms_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
"""
Generate embeddings for a list of texts using lollms server.
Args:
texts: List of strings to embed
embed_model: Model name (not used directly as lollms uses configured vectorizer)
base_url: URL of the lollms server
**kwargs: Additional arguments passed to the request
Returns:
np.ndarray: Array of embeddings
"""
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": api_key}
if api_key
else {"Content-Type": "application/json"}
)
async with aiohttp.ClientSession(headers=headers) as session:
embeddings = []
for text in texts:
request_data = {"text": text}
async with session.post(
f"{base_url}/lollms_embed",
json=request_data,
) as response:
result = await response.json()
embeddings.append(result["vector"])
return np.array(embeddings)
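A usage sketch, assuming a LoLLMs server is running on the default `base_url` used throughout this module and has a vectorizer configured:

```python
import asyncio

from lightrag.llm.lollms import lollms_embed


async def main():
    # Assumes a LoLLMs server on http://localhost:9600 with a vectorizer configured.
    vectors = await lollms_embed(["hello world"], base_url="http://localhost:9600")
    print(vectors.shape)


asyncio.run(main())
```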

View File

@@ -0,0 +1,112 @@
"""
OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with openai's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.nvidia_openai import nvidia_openai_model_complete, nvidia_openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import os
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def nvidia_openai_embed(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
input_type: str = "passage", # query for retrieval, passage for embedding
trunc: str = "NONE", # NONE or START or END
encode: str = "float", # float or base64
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model,
input=texts,
encoding_format=encode,
extra_body={"input_type": input_type, "truncate": trunc},
)
return np.array([dp.embedding for dp in response.data])
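A usage sketch for the NVIDIA embedding helper; the module path and model name are taken from the Usage note and the NVIDIA demo earlier in this commit, and the API key is a placeholder:

```python
import asyncio

from lightrag.llm.nvidia_openai import nvidia_openai_embed


async def main():
    vectors = await nvidia_openai_embed(
        ["hello world"],
        model="nvidia/nv-embedqa-e5-v5",  # model used in the NVIDIA demo above
        api_key="<NVIDIA_API_KEY>",       # placeholder
        input_type="passage",             # "query" at retrieval time, "passage" at indexing time
    )
    print(vectors.shape)


asyncio.run(main())
```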

155
lightrag/llm/ollama.py Normal file
View File

@@ -0,0 +1,155 @@
"""
Ollama LLM Interface Module
==========================
This module provides interfaces for interacting with Ollama's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- ollama
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.ollama_interface import ollama_model_complete, ollama_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("ollama"):
pm.install("ollama")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import ollama
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import numpy as np
from typing import Union
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def ollama_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> Union[str, AsyncIterator[str]]:
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
if api_key
else {"Content-Type": "application/json"}
)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
"""cannot cache stream response"""
async def inner():
async for chunk in response:
yield chunk["message"]["content"]
return inner()
else:
return response["message"]["content"]
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"""
Deprecated in favor of `ollama_embed`.
"""
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return np.array(embed_text)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": api_key}
if api_key
else {"Content-Type": "application/json"}
)
kwargs["headers"] = headers
ollama_client = ollama.Client(**kwargs)
data = ollama_client.embed(model=embed_model, input=texts)
return data["embeddings"]
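
A minimal sketch of calling the Ollama helpers above directly, outside of LightRAG; the model names are assumptions and must already be pulled into the local Ollama server:

```python
import asyncio

async def demo():
    answer = await ollama_model_if_cache(
        "qwen2.5:7b",  # assumed chat model
        "What does RAG stand for?",
        system_prompt="Answer in one sentence.",
        host="http://localhost:11434",
    )
    print(answer)

    vectors = await ollama_embed(
        ["retrieval augmented generation"],
        embed_model="nomic-embed-text",  # assumed embedding model
        host="http://localhost:11434",
    )
    print(len(vectors[0]))

asyncio.run(demo())
```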

232
lightrag/llm/openai.py Normal file
View File

@@ -0,0 +1,232 @@
"""
OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with OpenAI's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from lightrag.llm.openai import openai_complete, openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import os
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
from typing import Union
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Log the query that is sent to the LLM
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
logger.debug("Full context:")
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await openai_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def nvidia_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await openai_complete_if_cache(
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
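
A minimal sketch of driving the completion and embedding helpers above; it assumes `OPENAI_API_KEY` is set in the environment and the prompt text is purely illustrative:

```python
import asyncio

async def demo():
    reply = await gpt_4o_mini_complete(
        "Summarize what a knowledge graph is in one sentence."
    )
    print(reply)

    vectors = await openai_embed(["knowledge graph", "vector store"])
    print(vectors.shape)  # (2, 1536) with the default text-embedding-3-small

asyncio.run(demo())
```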

121
lightrag/llm/siliconcloud.py Normal file
View File

@@ -0,0 +1,121 @@
"""
SiliconCloud Embedding Interface Module
==========================
This module provides interfaces for interacting with the SiliconCloud system,
including embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added embedding generation
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from lightrag.llm.siliconcloud import siliconcloud_embedding
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import copy
import os
import json
import base64
import struct
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy")
from openai import (
AsyncOpenAI,
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from functools import lru_cache
import numpy as np
from typing import Union
import aiohttp
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def siliconcloud_embedding(
texts: list[str],
model: str = "netease-youdao/bce-embedding-base_v1",
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if "code" in content:
raise ValueError(content)
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
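
A minimal sketch of the SiliconCloud embedding call above; `SILICONFLOW_API_KEY` is an assumed environment variable name:

```python
import asyncio
import os

async def demo():
    vectors = await siliconcloud_embedding(
        ["LightRAG ships one module per LLM backend."],
        api_key=os.getenv("SILICONFLOW_API_KEY"),
    )
    print(vectors.shape)  # (1, embedding_dim)

asyncio.run(demo())
```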

250
lightrag/llm/zhipu.py Normal file
View File

@@ -0,0 +1,250 @@
"""
Zhipu LLM Interface Module
==========================
This module provides interfaces for interacting with Zhipu's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import re
import json
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("zhipuai"):
pm.install("zhipuai")
from openai import (
AsyncOpenAI,
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from functools import lru_cache
import numpy as np
from typing import Union, List, Optional, Dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
**kwargs,
) -> str:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
messages = []
if not system_prompt:
system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
}
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
return response.choices[0].message.content
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
# Add a system prompt to guide the model to return JSON format
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
# Combine with existing system prompt if any
if system_prompt:
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
else:
system_prompt = extraction_prompt
try:
response = await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
match = re.search(r"\{[\s\S]*\}", response)
if match:
try:
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
pass
# If all parsing fails, log warning and return empty format
logger.warning(
f"Failed to parse keyword extraction response: {response}"
)
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
except Exception as e:
logger.error(f"Error during keyword extraction: {str(e)}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
else:
# For non-keyword-extraction, just return the raw response string
return await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_embedding(
texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
# Convert single text to list if needed
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
try:
response = client.embeddings.create(model=model, input=[text], **kwargs)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
return np.array(embeddings)
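
A minimal sketch of the Zhipu helpers above; it assumes `ZHIPUAI_API_KEY` is exported so that `ZhipuAI()` can pick it up from the environment:

```python
import asyncio

async def demo():
    answer = await zhipu_complete(
        "List two use cases for retrieval augmented generation."
    )
    print(answer)

    vectors = await zhipu_embedding(["retrieval augmented generation"])
    print(vectors.shape)  # (1, embedding_dim)

asyncio.run(demo())
```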

View File

@@ -6,6 +6,8 @@ from dataclasses import dataclass
from typing import Any, Union, cast, Dict
import networkx as nx
import numpy as np
import pipmaster as pm
from nano_vectordb import NanoVectorDB
import time

6
lightrag/types.py Normal file
View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
from typing import List
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str]
low_level_keywords: List[str]
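
`GPTKeywordExtractionFormat` is the structured-output schema the GPT helpers pass as `response_format`; a minimal sketch of validating a raw JSON reply against it (the sample payload is made up):

```python
import json

raw = '{"high_level_keywords": ["graph RAG"], "low_level_keywords": ["entity", "relation"]}'
keywords = GPTKeywordExtractionFormat(**json.loads(raw))
print(keywords.high_level_keywords, keywords.low_level_keywords)
```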

View File

@@ -535,7 +535,7 @@ class CacheData:
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
cache_type: str ="query"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):

View File

@@ -5,7 +5,7 @@ import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
## For Upstage API
@@ -25,7 +25,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),

View File

@@ -4,7 +4,7 @@ import json
import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
@@ -26,7 +26,7 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),

View File

@@ -1,22 +1,20 @@
accelerate
aioboto3
aiofiles
aiohttp
aioredis
redis
asyncpg
configparser
# database packages
graspologic
gremlinpython
hnswlib
nano-vectordb
neo4j
networkx
# TODO : Remove specific databases and move the installation to their corresponding files
# Use pipmaster for install if needed
numpy
ollama
openai
oracledb
pipmaster
psycopg-pool
@@ -33,14 +31,12 @@ python-dotenv
python-pptx
pyvis
setuptools
# lmdeploy[all]
sqlalchemy
tenacity
# LLM packages
tiktoken
torch
tqdm
transformers
xxhash
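
The TODO above refers to the pipmaster pattern already used by the new LLM modules: check for an optional dependency at import time and install it only when it is missing. A minimal sketch, with `neo4j` used purely as an illustrative package name:

```python
import pipmaster as pm  # Pipmaster for dynamic library install

# Install a backend-specific dependency only when that backend is actually used
if not pm.is_installed("neo4j"):
    pm.install("neo4j")

import neo4j  # safe to import once the check above has run
```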