61
examples/lightrag_zhipu_demo.py
Normal file
61
examples/lightrag_zhipu_demo.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm import zhipu_complete, zhipu_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
|
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
|
||||||
|
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
api_key = os.environ.get("ZHIPUAI_API_KEY")
|
||||||
|
if api_key is None:
|
||||||
|
raise Exception("Please set ZHIPU_API_KEY in your environment")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=zhipu_complete,
|
||||||
|
llm_model_name="glm-4-flashx", # Using the most cost/performance balance model, but you can change it here.
|
||||||
|
llm_model_max_async=4,
|
||||||
|
llm_model_max_token_size=32768,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=2048, # Zhipu embedding-3 dimension
|
||||||
|
max_token_size=8192,
|
||||||
|
func=lambda texts: zhipu_embedding(
|
||||||
|
texts
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||||
|
rag.insert(f.read())
|
||||||
|
|
||||||
|
# Perform naive search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform local search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform global search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform hybrid search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
176
lightrag/llm.py
176
lightrag/llm.py
@@ -2,9 +2,10 @@ import base64
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import struct
|
import struct
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Dict, Callable, Any, Union
|
from typing import List, Dict, Callable, Any, Union, Optional
|
||||||
import aioboto3
|
import aioboto3
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -596,6 +597,179 @@ async def ollama_model_complete(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
|
)
|
||||||
|
async def zhipu_complete_if_cache(
|
||||||
|
prompt: Union[str, List[Dict[str, str]]],
|
||||||
|
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
history_messages: List[Dict[str, str]] = [],
|
||||||
|
**kwargs
|
||||||
|
) -> str:
|
||||||
|
# dynamically load ZhipuAI
|
||||||
|
try:
|
||||||
|
from zhipuai import ZhipuAI
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
client = ZhipuAI(api_key=api_key)
|
||||||
|
else:
|
||||||
|
# please set ZHIPUAI_API_KEY in your environment
|
||||||
|
# os.environ["ZHIPUAI_API_KEY"]
|
||||||
|
client = ZhipuAI()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
if not system_prompt:
|
||||||
|
system_prompt = "You are a helpful assistant. 注意内容里的敏感词用***替换。"
|
||||||
|
|
||||||
|
# Add system prompt if provided
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.extend(history_messages)
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
# Add debug logging
|
||||||
|
logger.debug("===== Query Input to LLM =====")
|
||||||
|
logger.debug(f"Query: {prompt}")
|
||||||
|
logger.debug(f"System prompt: {system_prompt}")
|
||||||
|
|
||||||
|
# Remove unsupported kwargs
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
|
async def zhipu_complete(
|
||||||
|
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||||
|
):
|
||||||
|
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
|
||||||
|
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||||
|
|
||||||
|
if keyword_extraction:
|
||||||
|
# Add a system prompt to guide the model to return JSON format
|
||||||
|
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
|
||||||
|
Please analyze the content and extract two types of keywords:
|
||||||
|
1. High-level keywords: Important concepts and main themes
|
||||||
|
2. Low-level keywords: Specific details and supporting elements
|
||||||
|
|
||||||
|
Return your response in this exact JSON format:
|
||||||
|
{
|
||||||
|
"high_level_keywords": ["keyword1", "keyword2"],
|
||||||
|
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
|
||||||
|
}
|
||||||
|
|
||||||
|
Only return the JSON, no other text."""
|
||||||
|
|
||||||
|
# Combine with existing system prompt if any
|
||||||
|
if system_prompt:
|
||||||
|
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
|
||||||
|
else:
|
||||||
|
system_prompt = extraction_prompt
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await zhipu_complete_if_cache(
|
||||||
|
prompt=prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to parse as JSON
|
||||||
|
try:
|
||||||
|
data = json.loads(response)
|
||||||
|
return GPTKeywordExtractionFormat(
|
||||||
|
high_level_keywords=data.get("high_level_keywords", []),
|
||||||
|
low_level_keywords=data.get("low_level_keywords", [])
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If direct JSON parsing fails, try to extract JSON from text
|
||||||
|
match = re.search(r"\{[\s\S]*\}", response)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
data = json.loads(match.group())
|
||||||
|
return GPTKeywordExtractionFormat(
|
||||||
|
high_level_keywords=data.get("high_level_keywords", []),
|
||||||
|
low_level_keywords=data.get("low_level_keywords", [])
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If all parsing fails, log warning and return empty format
|
||||||
|
logger.warning(f"Failed to parse keyword extraction response: {response}")
|
||||||
|
return GPTKeywordExtractionFormat(
|
||||||
|
high_level_keywords=[], low_level_keywords=[]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during keyword extraction: {str(e)}")
|
||||||
|
return GPTKeywordExtractionFormat(
|
||||||
|
high_level_keywords=[], low_level_keywords=[]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For non-keyword-extraction, just return the raw response string
|
||||||
|
return await zhipu_complete_if_cache(
|
||||||
|
prompt=prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
|
)
|
||||||
|
async def zhipu_embedding(
|
||||||
|
texts: list[str],
|
||||||
|
model: str = "embedding-3",
|
||||||
|
api_key: str = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
|
||||||
|
# dynamically load ZhipuAI
|
||||||
|
try:
|
||||||
|
from zhipuai import ZhipuAI
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
|
||||||
|
if api_key:
|
||||||
|
client = ZhipuAI(api_key=api_key)
|
||||||
|
else:
|
||||||
|
# please set ZHIPUAI_API_KEY in your environment
|
||||||
|
# os.environ["ZHIPUAI_API_KEY"]
|
||||||
|
client = ZhipuAI()
|
||||||
|
|
||||||
|
# Convert single text to list if needed
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
for text in texts:
|
||||||
|
try:
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model=model,
|
||||||
|
input=[text],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
embeddings.append(response.data[0].embedding)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
|
||||||
|
|
||||||
|
return np.array(embeddings)
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
Reference in New Issue
Block a user