Update __version__

Author: LarFii
Date: 2024-12-13 20:15:49 +08:00
Parent: 9cac3b0ed7
Commit: b7a2d336e6
5 changed files with 29 additions and 39 deletions

View File

@@ -1,9 +1,6 @@
-import asyncio
 import os
-import inspect
 import logging
-from dotenv import load_dotenv
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import zhipu_complete, zhipu_embedding
@@ -21,7 +18,6 @@ if api_key is None:
     raise Exception("Please set ZHIPU_API_KEY in your environment")
 
-
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=zhipu_complete,
@@ -31,9 +27,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=2048,  # Zhipu embedding-3 dimension
         max_token_size=8192,
-        func=lambda texts: zhipu_embedding(
-            texts
-        ),
+        func=lambda texts: zhipu_embedding(texts),
     ),
 )

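The hunks above only touch the demo's imports and formatting; the initialization itself is unchanged. For orientation, a minimal usage sketch of the configured rag instance might look like the following (the "./book.txt" path and the query string are placeholders for illustration, not part of this commit):

# Hedged sketch: assumes the rag instance built above and a valid ZHIPU_API_KEY.
# "./book.txt" and the query text are illustrative placeholders.
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

print(
    rag.query(
        "What are the top themes in this story?",
        param=QueryParam(mode="hybrid"),
    )
)
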
View File

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.0.5"
+__version__ = "1.0.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

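The only change here is the version bump from 1.0.5 to 1.0.6. After upgrading, the installed version can be confirmed with a quick check like:

import lightrag
print(lightrag.__version__)  # expected to print "1.0.6" after this commit
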
View File

@@ -63,7 +63,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
             return result
 
         embedding_tasks = [wrapped_task(batch) for batch in batches]
-        pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
         embeddings_list = await asyncio.gather(*embedding_tasks)
         embeddings = np.concatenate(embeddings_list)

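The pbar reformat above belongs to a pattern where each embedding batch bumps a shared async progress bar as it finishes, before the results are gathered and concatenated. A self-contained sketch of that pattern, with a dummy fake_embed coroutine standing in for the storage's real embedding function (names, batch contents, and vector sizes here are illustrative, not from the commit):

import asyncio

import numpy as np
from tqdm.asyncio import tqdm as tqdm_async


async def fake_embed(batch: list[str]) -> np.ndarray:
    # Stand-in for a real embedding call; returns one 8-dim vector per text.
    await asyncio.sleep(0.1)
    return np.zeros((len(batch), 8))


async def main() -> None:
    batches = [["a", "b"], ["c"], ["d", "e", "f"]]
    pbar = tqdm_async(
        total=len(batches), desc="Generating embeddings", unit="batch"
    )

    async def wrapped_task(batch):
        # Each task updates the shared bar as soon as its batch is embedded.
        result = await fake_embed(batch)
        pbar.update(1)
        return result

    embedding_tasks = [wrapped_task(batch) for batch in batches]
    embeddings_list = await asyncio.gather(*embedding_tasks)
    pbar.close()
    embeddings = np.concatenate(embeddings_list)
    print(embeddings.shape)  # (6, 8)


asyncio.run(main())
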
View File

@@ -608,7 +608,7 @@ async def zhipu_complete_if_cache(
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
     history_messages: List[Dict[str, str]] = [],
-    **kwargs
+    **kwargs,
 ) -> str:
     # dynamically load ZhipuAI
     try:
@@ -640,13 +640,11 @@ async def zhipu_complete_if_cache(
     logger.debug(f"System prompt: {system_prompt}")
 
     # Remove unsupported kwargs
-    kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
+    kwargs = {
+        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
+    }
 
-    response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        **kwargs
-    )
+    response = client.chat.completions.create(model=model, messages=messages, **kwargs)
 
     return response.choices[0].message.content
@@ -683,7 +681,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
-            **kwargs
+            **kwargs,
         )
 
         # Try to parse as JSON
@@ -691,7 +689,7 @@ async def zhipu_complete(
             data = json.loads(response)
             return GPTKeywordExtractionFormat(
                 high_level_keywords=data.get("high_level_keywords", []),
-                low_level_keywords=data.get("low_level_keywords", [])
+                low_level_keywords=data.get("low_level_keywords", []),
             )
         except json.JSONDecodeError:
             # If direct JSON parsing fails, try to extract JSON from text
@@ -701,13 +699,15 @@ async def zhipu_complete(
                     data = json.loads(match.group())
                     return GPTKeywordExtractionFormat(
                         high_level_keywords=data.get("high_level_keywords", []),
-                        low_level_keywords=data.get("low_level_keywords", [])
+                        low_level_keywords=data.get("low_level_keywords", []),
                     )
                 except json.JSONDecodeError:
                     pass
 
             # If all parsing fails, log warning and return empty format
-            logger.warning(f"Failed to parse keyword extraction response: {response}")
+            logger.warning(
+                f"Failed to parse keyword extraction response: {response}"
+            )
             return GPTKeywordExtractionFormat(
                 high_level_keywords=[], low_level_keywords=[]
             )
@@ -722,7 +722,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
-            **kwargs
+            **kwargs,
         )
@@ -733,13 +733,9 @@ async def zhipu_complete(
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def zhipu_embedding(
-    texts: list[str],
-    model: str = "embedding-3",
-    api_key: str = None,
-    **kwargs
+    texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
 ) -> np.ndarray:
-
     # dynamically load ZhipuAI
     try:
         from zhipuai import ZhipuAI
     except ImportError:
@@ -758,11 +754,7 @@ async def zhipu_embedding(
     embeddings = []
     for text in texts:
         try:
-            response = client.embeddings.create(
-                model=model,
-                input=[text],
-                **kwargs
-            )
+            response = client.embeddings.create(model=model, input=[text], **kwargs)
             embeddings.append(response.data[0].embedding)
         except Exception as e:
             raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")

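Several hunks above reformat the keyword-extraction fallback in zhipu_complete: parse the model response as JSON and, if that fails, pull the first {...} block out of the free-form text and try again, returning empty keyword lists when nothing parses. A small standalone sketch of that strategy (parse_keywords, the regex, and the sample response are assumptions for illustration, not code lifted from this file):

import json
import re


def parse_keywords(response: str) -> dict:
    # Try strict JSON first, then fall back to the first {...} block in the text.
    def load(text: str):
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    data = load(response)
    if data is None:
        match = re.search(r"\{[\s\S]*\}", response)
        data = load(match.group()) if match else None
    if not isinstance(data, dict):
        data = {}  # nothing parsable: fall back to empty keyword lists

    return {
        "high_level_keywords": data.get("high_level_keywords", []),
        "low_level_keywords": data.get("low_level_keywords", []),
    }


sample = 'Sure! {"high_level_keywords": ["retrieval"], "low_level_keywords": ["Milvus"]}'
print(parse_keywords(sample))
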
View File

@@ -103,7 +103,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
             return result
 
         embedding_tasks = [wrapped_task(batch) for batch in batches]
-        pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
         embeddings_list = await asyncio.gather(*embedding_tasks)
         embeddings = np.concatenate(embeddings_list)