From 1ef973c7fc0d9126954dfdc31ecd53643b6c18cd Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Tue, 22 Oct 2024 15:16:57 +0800
Subject: [PATCH] feat(examples): support siliconcloud free API

---
 README.md                              |  1 +
 examples/lightrag_siliconcloud_demo.py | 79 ++++++++++++++++++++++++++
 lightrag/llm.py                        | 48 +++++++++++++++-
 requirements.txt                       |  3 +-
 4 files changed, 129 insertions(+), 2 deletions(-)
 create mode 100644 examples/lightrag_siliconcloud_demo.py

diff --git a/README.md b/README.md
index 76535d19..87335f1f 100644
--- a/README.md
+++ b/README.md
@@ -629,6 +629,7 @@ def extract_queries(file_path):
 │   ├── lightrag_ollama_demo.py
 │   ├── lightrag_openai_compatible_demo.py
 │   ├── lightrag_openai_demo.py
+│   ├── lightrag_siliconcloud_demo.py
 │   └── vram_management_demo.py
 ├── lightrag
 │   ├── __init__.py
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
new file mode 100644
index 00000000..e3f5e67e
--- /dev/null
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -0,0 +1,79 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        "Qwen/Qwen2.5-7B-Instruct",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        base_url="https://api.siliconflow.cn/v1/",
+        **kwargs,
+    )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        model="netease-youdao/bce-embedding-base_v1",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        max_token_size=int(512 * 1.5)
+    )
+
+
+# Sanity-check the LLM and embedding functions before building the index
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func: ", result)
+
+
+asyncio.run(test_funcs())
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=768, max_token_size=512, func=embedding_func
+    ),
+)
+
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..06d75d01 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -2,8 +2,11 @@ import os
 import copy
 import json
 import aioboto3
+import aiohttp
 import numpy as np
 import ollama
+import base64
+import struct
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
 from tenacity import (
     retry,
@@ -312,7 +315,7 @@ async def ollama_model_complete(
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def openai_embedding(
@@ -332,6 +335,49 @@ async def openai_embedding(
     )
     return np.array([dp.embedding for dp in response.data])
 
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def siliconcloud_embedding(
+    texts: list[str],
+    model: str = "netease-youdao/bce-embedding-base_v1",
+    base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+    max_token_size: int = 512,
+    api_key: str = None,
+) -> np.ndarray:
+    if api_key and not api_key.startswith('Bearer '):
+        api_key = 'Bearer ' + api_key
+
+    headers = {
+        "Authorization": api_key,
+        "Content-Type": "application/json"
+    }
+
+    truncate_texts = [text[0:max_token_size] for text in texts]
+
+    payload = {
+        "model": model,
+        "input": truncate_texts,
+        "encoding_format": "base64"
+    }
+
+    base64_strings = []
+    async with aiohttp.ClientSession() as session:
+        async with session.post(base_url, headers=headers, json=payload) as response:
+            content = await response.json()
+            if 'code' in content:
+                raise ValueError(content)
+            base64_strings = [item['embedding'] for item in content['data']]
+
+    embeddings = []
+    for string in base64_strings:
+        decode_bytes = base64.b64decode(string)
+        n = len(decode_bytes) // 4
+        float_array = struct.unpack('<' + 'f' * n, decode_bytes)
+        embeddings.append(float_array)
+    return np.array(embeddings)
 
 # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 # @retry(
diff --git a/requirements.txt b/requirements.txt
index 9cc5b7e9..5b3396fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,5 @@ tiktoken
 torch
 transformers
 xxhash
-pyvis
\ No newline at end of file
+pyvis
+aiohttp
\ No newline at end of file
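
Note on the new siliconcloud_embedding helper: because the request sets "encoding_format": "base64", each embedding in the response arrives as a base64 string of packed little-endian float32 values rather than a JSON float list, which the patch decodes with struct.unpack. A minimal standalone sketch of that decode step (the decode_f32_base64 name and the sample values are illustrative, not part of the patch):

    import base64
    import struct

    import numpy as np

    def decode_f32_base64(b64_string: str) -> np.ndarray:
        # Each float32 occupies 4 bytes; '<' selects little-endian byte
        # order, mirroring struct.unpack('<' + 'f' * n, ...) in the patch.
        raw = base64.b64decode(b64_string)
        count = len(raw) // 4
        return np.array(struct.unpack("<" + "f" * count, raw), dtype=np.float32)

    # Round-trip three sample floats through base64 to exercise the decoder.
    encoded = base64.b64encode(struct.pack("<3f", 0.1, 0.2, 0.3)).decode()
    print(decode_f32_base64(encoded))  # ~[0.1 0.2 0.3]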
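
To try the helper outside the full demo, a minimal sketch, assuming SILICONFLOW_API_KEY is exported and the helper's default netease-youdao/bce-embedding-base_v1 model (768-dimensional, matching the embedding_dim=768 the demo passes to EmbeddingFunc):

    import asyncio
    import os

    from lightrag.llm import siliconcloud_embedding

    async def main() -> None:
        # Uses the default model and endpoint baked into the helper.
        vectors = await siliconcloud_embedding(
            ["hello world"],
            api_key=os.getenv("SILICONFLOW_API_KEY"),
        )
        print(vectors.shape)  # expected (1, 768) for bce-embedding-base_v1

    asyncio.run(main())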