From 6a0e9c6c7745972b2d117e3c75de5485bdb13139 Mon Sep 17 00:00:00 2001
From: Jason Guo <134677535+JasonGuoo@users.noreply.github.com>
Date: Fri, 13 Dec 2024 16:18:33 +0800
Subject: [PATCH 1/2] Modify the chat_complete method to support keyword extraction.
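
A minimal usage sketch of the new keyword-extraction path (hypothetical
call site; assumes ZHIPUAI_API_KEY is set in the environment):

    import asyncio
    from lightrag.llm import zhipu_complete

    async def main():
        # With keyword_extraction=True, zhipu_complete parses the model
        # output into a GPTKeywordExtractionFormat instead of returning
        # the raw completion string.
        keywords = await zhipu_complete(
            "LightRAG combines graph-based indexing with vector retrieval.",
            keyword_extraction=True,
        )
        print(keywords.high_level_keywords)
        print(keywords.low_level_keywords)

    asyncio.run(main())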
---
 examples/lightrag_zhipu_demo.py |  61 +++++++++++
 lightrag/llm.py                 | 175 +++++++++++++++++++++++++++++++-
 2 files changed, 235 insertions(+), 1 deletion(-)
 create mode 100644 examples/lightrag_zhipu_demo.py

diff --git a/examples/lightrag_zhipu_demo.py b/examples/lightrag_zhipu_demo.py
new file mode 100644
index 00000000..bcade616
--- /dev/null
+++ b/examples/lightrag_zhipu_demo.py
@@ -0,0 +1,61 @@
+import asyncio
+import os
+import inspect
+import logging
+
+from dotenv import load_dotenv
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import zhipu_complete, zhipu_embedding
+from lightrag.utils import EmbeddingFunc
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+api_key = os.environ.get("ZHIPUAI_API_KEY")
+if api_key is None:
+    raise Exception("Please set ZHIPUAI_API_KEY in your environment")
+
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=zhipu_complete,
+    llm_model_name="glm-4-flashx",  # Best cost/performance model in the GLM-4 series; change it here if needed.
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=2048,  # Zhipu embedding-3 dimension
+        max_token_size=8192,
+        func=lambda texts: zhipu_embedding(
+            texts
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
\ No newline at end of file
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 636f03cb..4df835b0 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,7 +4,7 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union
+from typing import List, Dict, Callable, Any, Union, Optional
 import aioboto3
 import aiohttp
 import numpy as np
@@ -596,6 +596,179 @@ async def ollama_model_complete(
     )
 
 
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def zhipu_complete_if_cache(
+    prompt: Union[str, List[Dict[str, str]]],
+    model: str = "glm-4-flashx",  # Best cost/performance model in the GLM-4 series
+    api_key: Optional[str] = None,
+    system_prompt: Optional[str] = None,
+    history_messages: List[Dict[str, str]] = [],
+    **kwargs
+) -> str:
+    # dynamically load ZhipuAI
+    try:
+        from zhipuai import ZhipuAI
+    except ImportError:
+        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+
+    if api_key:
+        client = ZhipuAI(api_key=api_key)
+    else:
+        # please set ZHIPUAI_API_KEY in your environment
+        # os.environ["ZHIPUAI_API_KEY"]
+        client = ZhipuAI()
+
+    messages = []
+
+    if not system_prompt:
+        system_prompt = "You are a helpful assistant. Replace any sensitive words in the content with ***."
+
+    # Add system prompt if provided
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    # Add debug logging
+    logger.debug("===== Query Input to LLM =====")
+    logger.debug(f"Query: {prompt}")
+    logger.debug(f"System prompt: {system_prompt}")
+
+    # Remove unsupported kwargs
+    kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
+
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        **kwargs
+    )
+
+    return response.choices[0].message.content
+
+
+async def zhipu_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+):
+    # Honor the flag whether it arrives as a named argument or via **kwargs, and pop it so it is not forwarded to zhipu_complete_if_cache
+    keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", False)
+
+    if keyword_extraction:
+        # Add a system prompt to guide the model to return JSON format
+        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
+        Please analyze the content and extract two types of keywords:
+        1. High-level keywords: Important concepts and main themes
+        2. Low-level keywords: Specific details and supporting elements
+
+        Return your response in this exact JSON format:
+        {
+            "high_level_keywords": ["keyword1", "keyword2"],
+            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
+        }
+
+        Only return the JSON, no other text."""
+
+        # Combine with existing system prompt if any
+        if system_prompt:
+            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
+        else:
+            system_prompt = extraction_prompt
+
+        try:
+            response = await zhipu_complete_if_cache(
+                prompt=prompt,
+                system_prompt=system_prompt,
+                history_messages=history_messages,
+                **kwargs
+            )
+
+            # Try to parse as JSON
+            try:
+                data = json.loads(response)
+                return GPTKeywordExtractionFormat(
+                    high_level_keywords=data.get("high_level_keywords", []),
+                    low_level_keywords=data.get("low_level_keywords", [])
+                )
+            except json.JSONDecodeError:
+                # If direct JSON parsing fails, try to extract JSON from text
+                match = re.search(r"\{[\s\S]*\}", response)
+                if match:
+                    try:
+                        data = json.loads(match.group())
+                        return GPTKeywordExtractionFormat(
+                            high_level_keywords=data.get("high_level_keywords", []),
+                            low_level_keywords=data.get("low_level_keywords", [])
+                        )
+                    except json.JSONDecodeError:
+                        pass
+
+                # If all parsing fails, log warning and return empty format
+                logger.warning(f"Failed to parse keyword extraction response: {response}")
+                return GPTKeywordExtractionFormat(
+                    high_level_keywords=[], low_level_keywords=[]
+                )
+        except Exception as e:
+            logger.error(f"Error during keyword extraction: {str(e)}")
+            return GPTKeywordExtractionFormat(
+                high_level_keywords=[], low_level_keywords=[]
+            )
+    else:
+        # For non-keyword-extraction, just return the raw response string
+        return await zhipu_complete_if_cache(
+            prompt=prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs
+        )
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def zhipu_embedding(
+    texts: list[str],
+    model: str = "embedding-3",
+    api_key: Optional[str] = None,
+    **kwargs
+) -> np.ndarray:
+
+    # dynamically load ZhipuAI
+    try:
+        from zhipuai import ZhipuAI
+    except ImportError:
+        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+    if api_key:
+        client = ZhipuAI(api_key=api_key)
+    else:
+        # please set ZHIPUAI_API_KEY in your environment
+        # os.environ["ZHIPUAI_API_KEY"]
+        client = ZhipuAI()
+
+    # Convert single text to list if needed
+    if isinstance(texts, str):
+        texts = [texts]
+
+    embeddings = []
+    for text in texts:
+        try:
+            response = client.embeddings.create(
+                model=model,
+                input=[text],
+                **kwargs
+            )
+            embeddings.append(response.data[0].embedding)
+        except Exception as e:
+            raise Exception(f"Error calling Zhipu Embedding API: {str(e)}")
+
+    return np.array(embeddings)
+
+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),

From e64cf5068f1282138ada845c390560f08abdc02f Mon Sep 17 00:00:00 2001
From: Jason Guo <134677535+JasonGuoo@users.noreply.github.com>
Date: Fri, 13 Dec 2024 19:57:25 +0800
Subject: [PATCH 2/2] Fix import

The keyword-extraction fallback added in the previous patch calls
re.search(), but llm.py never imported re; add the missing import.
---
 .DS_Store       | Bin 8196 -> 8196 bytes
 lightrag/llm.py |   1 +
 2 files changed, 1 insertion(+)

diff --git a/.DS_Store b/.DS_Store
index 651e36edce1916e150c6c409938c130c393b622b..7489d923a9e7375a0aadb3d45e73e969d67f761d 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 4df835b0..591b5dc9 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -2,6 +2,7 @@ import base64
 import copy
 import json
 import os
+import re
 import struct
 from functools import lru_cache
 from typing import List, Dict, Callable, Any, Union, Optional