Merge pull request #462 from JasonGuoo/main

Supporting Zhipu AI API
This commit is contained in:
zrguo
2024-12-13 20:09:11 +08:00
committed by GitHub
3 changed files with 236 additions and 1 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@@ -0,0 +1,61 @@
import asyncio
import os
import inspect
import logging
from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.llm import zhipu_complete, zhipu_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise Exception("Please set ZHIPU_API_KEY in your environment")
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=zhipu_complete,
llm_model_name="glm-4-flashx", # Using the most cost/performance balance model, but you can change it here.
llm_model_max_async=4,
llm_model_max_token_size=32768,
embedding_func=EmbeddingFunc(
embedding_dim=2048, # Zhipu embedding-3 dimension
max_token_size=8192,
func=lambda texts: zhipu_embedding(
texts
),
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -2,9 +2,10 @@ import base64
import copy import copy
import json import json
import os import os
import re
import struct import struct
from functools import lru_cache from functools import lru_cache
from typing import List, Dict, Callable, Any, Union from typing import List, Dict, Callable, Any, Union, Optional
import aioboto3 import aioboto3
import aiohttp import aiohttp
import numpy as np import numpy as np
@@ -596,6 +597,179 @@ async def ollama_model_complete(
) )
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
**kwargs
) -> str:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
messages = []
if not system_prompt:
system_prompt = "You are a helpful assistant. 注意内容里的敏感词用***替换。"
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
response = client.chat.completions.create(
model=model,
messages=messages,
**kwargs
)
return response.choices[0].message.content
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
# Add a system prompt to guide the model to return JSON format
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
# Combine with existing system prompt if any
if system_prompt:
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
else:
system_prompt = extraction_prompt
try:
response = await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
)
# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", [])
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
match = re.search(r"\{[\s\S]*\}", response)
if match:
try:
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", [])
)
except json.JSONDecodeError:
pass
# If all parsing fails, log warning and return empty format
logger.warning(f"Failed to parse keyword extraction response: {response}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
except Exception as e:
logger.error(f"Error during keyword extraction: {str(e)}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
else:
# For non-keyword-extraction, just return the raw response string
return await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
texts: list[str],
model: str = "embedding-3",
api_key: str = None,
**kwargs
) -> np.ndarray:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
# Convert single text to list if needed
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
try:
response = client.embeddings.create(
model=model,
input=[text],
**kwargs
)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
return np.array(embeddings)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),