Merge pull request #96 from tpoisonooo/support-siliconcloud

feat(examples): support siliconcloud free API
zrguo committed on 2024-10-23 11:02:42 +08:00 (committed by GitHub)
4 changed files with 129 additions and 2 deletions
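
The patch adds an example script, lightrag_siliconcloud_demo.py, a new siliconcloud_embedding helper in lightrag/llm.py, and the aiohttp dependency it needs. As a quick orientation, here is a minimal sketch of calling the new helper on its own; the model and endpoint defaults are the ones the diff below introduces, while the SILICONFLOW_API_KEY environment variable name is an assumption, not part of the diff:

# Sketch: exercising the new embedding helper directly (assumes this patch
# is applied and a SiliconCloud API key is available in the environment).
import os
import asyncio

from lightrag.llm import siliconcloud_embedding


async def main() -> None:
    vecs = await siliconcloud_embedding(
        ["How are you?"],
        model="netease-youdao/bce-embedding-base_v1",  # patch default
        api_key=os.getenv("SILICONFLOW_API_KEY"),  # assumed env var name
    )
    print(vecs.shape)  # one 768-dim vector per input text


asyncio.run(main())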

View File: README.md

@@ -629,6 +629,7 @@ def extract_queries(file_path):
├── lightrag_ollama_demo.py
├── lightrag_openai_compatible_demo.py
├── lightrag_openai_demo.py
+├── lightrag_siliconcloud_demo.py
└── vram_management_demo.py
├── lightrag
├── __init__.py

View File: examples/lightrag_siliconcloud_demo.py (new file)

@@ -0,0 +1,79 @@
import os
import asyncio

import numpy as np

from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Chat completion via SiliconCloud's OpenAI-compatible endpoint.
    return await openai_complete_if_cache(
        "Qwen/Qwen2.5-7B-Instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("SILICONFLOW_API_KEY"),  # SiliconCloud API key
        base_url="https://api.siliconflow.cn/v1/",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        model="netease-youdao/bce-embedding-base_v1",
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        # Inputs are truncated by length as a rough proxy for the token limit.
        max_token_size=int(512 * 1.5),
    )


# function test
async def test_funcs():
    result = await llm_model_func("How are you?")
    print("llm_model_func: ", result)

    result = await embedding_func(["How are you?"])
    print("embedding_func: ", result)


asyncio.run(test_funcs())

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    # bce-embedding-base_v1 produces 768-dimensional vectors.
    embedding_func=EmbeddingFunc(
        embedding_dim=768, max_token_size=512, func=embedding_func
    ),
)

with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
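
Running the demo end-to-end requires a plain-text ./book.txt next to the script and the API key in the environment. A quick way to smoke-test without a real book (a sketch; the text below is a placeholder, not part of the commit):

# Placeholder corpus for a smoke test; any plain text works here.
sample_text = (
    "Ebenezer Scrooge is visited by three spirits and learns generosity. "
    "Themes include redemption, compassion, and the spirit of Christmas."
)
with open("./book.txt", "w") as f:
    f.write(sample_text)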

View File: lightrag/llm.py

@@ -2,8 +2,11 @@ import os
import copy
import json
import aioboto3
+import aiohttp
import numpy as np
import ollama
+import base64
+import struct
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
    retry,
@@ -312,7 +315,7 @@ async def ollama_model_complete(
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
@@ -332,6 +335,49 @@ async def openai_embedding(
    )
    return np.array([dp.embedding for dp in response.data])


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def siliconcloud_embedding(
    texts: list[str],
    model: str = "netease-youdao/bce-embedding-base_v1",
    base_url: str = "https://api.siliconflow.cn/v1/embeddings",
    max_token_size: int = 512,
    api_key: str = None,
) -> np.ndarray:
    if api_key and not api_key.startswith("Bearer "):
        api_key = "Bearer " + api_key

    headers = {"Authorization": api_key, "Content-Type": "application/json"}

    # Truncate by length as a cheap approximation of the model's token limit.
    truncate_texts = [text[0:max_token_size] for text in texts]

    payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}

    base64_strings = []
    async with aiohttp.ClientSession() as session:
        async with session.post(base_url, headers=headers, json=payload) as response:
            content = await response.json()
            # An error response carries a "code" field; surface it verbatim.
            if "code" in content:
                raise ValueError(content)
            base64_strings = [item["embedding"] for item in content["data"]]

    # Each embedding arrives as base64-encoded little-endian float32 bytes.
    embeddings = []
    for string in base64_strings:
        decode_bytes = base64.b64decode(string)
        n = len(decode_bytes) // 4
        float_array = struct.unpack("<" + "f" * n, decode_bytes)
        embeddings.append(float_array)
    return np.array(embeddings)
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
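
The decoding loop in siliconcloud_embedding assumes each returned embedding is a base64-encoded buffer of little-endian float32 values. A standalone round-trip check of that unpacking logic (a sketch, independent of the API):

import base64
import struct

import numpy as np

# Pack a known float32 vector, base64-encode it the way the API response is
# assumed to, then decode with the same "<f"-style unpacking used above.
original = [0.25, -1.5, 3.0]
encoded = base64.b64encode(struct.pack("<" + "f" * len(original), *original))

decoded_bytes = base64.b64decode(encoded)
n = len(decoded_bytes) // 4  # 4 bytes per float32 value
decoded = struct.unpack("<" + "f" * n, decoded_bytes)

assert np.allclose(decoded, np.array(original, dtype=np.float32))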

View File: requirements.txt

@@ -11,4 +11,5 @@ tiktoken
torch
transformers
xxhash
pyvis
+aiohttp