Merge pull request #96 from tpoisonooo/support-siliconcloud
feat(examples): support siliconcloud free API
This commit is contained in:
@@ -629,6 +629,7 @@ def extract_queries(file_path):
|
|||||||
│ ├── lightrag_ollama_demo.py
|
│ ├── lightrag_ollama_demo.py
|
||||||
│ ├── lightrag_openai_compatible_demo.py
|
│ ├── lightrag_openai_compatible_demo.py
|
||||||
│ ├── lightrag_openai_demo.py
|
│ ├── lightrag_openai_demo.py
|
||||||
|
│ ├── lightrag_siliconcloud_demo.py
|
||||||
│ └── vram_management_demo.py
|
│ └── vram_management_demo.py
|
||||||
├── lightrag
|
├── lightrag
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
|
79
examples/lightrag_siliconcloud_demo.py
Normal file
79
examples/lightrag_siliconcloud_demo.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
"Qwen/Qwen2.5-7B-Instruct",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.siliconflow.cn/v1/",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await siliconcloud_embedding(
|
||||||
|
texts,
|
||||||
|
model="netease-youdao/bce-embedding-base_v1",
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
max_token_size=int(512 * 1.5)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# function test
|
||||||
|
async def test_funcs():
|
||||||
|
result = await llm_model_func("How are you?")
|
||||||
|
print("llm_model_func: ", result)
|
||||||
|
|
||||||
|
result = await embedding_func(["How are you?"])
|
||||||
|
print("embedding_func: ", result)
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(test_funcs())
|
||||||
|
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=768, max_token_size=512, func=embedding_func
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
with open("./book.txt") as f:
|
||||||
|
rag.insert(f.read())
|
||||||
|
|
||||||
|
# Perform naive search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform local search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform global search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform hybrid search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
@@ -2,8 +2,11 @@ import os
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import aioboto3
|
import aioboto3
|
||||||
|
import aiohttp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ollama
|
import ollama
|
||||||
|
import base64
|
||||||
|
import struct
|
||||||
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
@@ -312,7 +315,7 @@ async def ollama_model_complete(
|
|||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
)
|
)
|
||||||
async def openai_embedding(
|
async def openai_embedding(
|
||||||
@@ -332,6 +335,49 @@ async def openai_embedding(
|
|||||||
)
|
)
|
||||||
return np.array([dp.embedding for dp in response.data])
|
return np.array([dp.embedding for dp in response.data])
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
|
)
|
||||||
|
async def siliconcloud_embedding(
|
||||||
|
texts: list[str],
|
||||||
|
model: str = "netease-youdao/bce-embedding-base_v1",
|
||||||
|
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
|
||||||
|
max_token_size: int = 512,
|
||||||
|
api_key: str = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
if api_key and not api_key.startswith('Bearer '):
|
||||||
|
api_key = 'Bearer ' + api_key
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": api_key,
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
truncate_texts = [text[0:max_token_size] for text in texts]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"input": truncate_texts,
|
||||||
|
"encoding_format": "base64"
|
||||||
|
}
|
||||||
|
|
||||||
|
base64_strings = []
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||||
|
content = await response.json()
|
||||||
|
if 'code' in content:
|
||||||
|
raise ValueError(content)
|
||||||
|
base64_strings = [item['embedding'] for item in content['data']]
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
for string in base64_strings:
|
||||||
|
decode_bytes = base64.b64decode(string)
|
||||||
|
n = len(decode_bytes) // 4
|
||||||
|
float_array = struct.unpack('<' + 'f' * n, decode_bytes)
|
||||||
|
embeddings.append(float_array)
|
||||||
|
return np.array(embeddings)
|
||||||
|
|
||||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
# @retry(
|
# @retry(
|
||||||
|
@@ -12,3 +12,4 @@ torch
|
|||||||
transformers
|
transformers
|
||||||
xxhash
|
xxhash
|
||||||
pyvis
|
pyvis
|
||||||
|
aiohttp
|
Reference in New Issue
Block a user