Merge branch 'main' into upload-error

This commit is contained in:
yangdx
2025-03-28 15:34:38 +08:00
6 changed files with 341 additions and 9 deletions

View File

@@ -0,0 +1,153 @@
# pip install -q -U google-genai to use gemini as a client
import os
import asyncio
import numpy as np
import nest_asyncio
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import setup_logger
from lightrag.utils import TokenTracker
setup_logger("lightrag", level="DEBUG")
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
token_tracker = TokenTracker()
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# 1. Initialize the GenAI Client with your Gemini API Key
client = genai.Client(api_key=gemini_api_key)
# 2. Combine prompts: system prompt, history, and user prompt
if history_messages is None:
history_messages = []
combined_prompt = ""
if system_prompt:
combined_prompt += f"{system_prompt}\n"
for msg in history_messages:
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
combined_prompt += f"{msg['role']}: {msg['content']}\n"
# Finally, add the new user prompt
combined_prompt += f"user: {prompt}"
# 3. Call the Gemini model
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=[combined_prompt],
config=types.GenerateContentConfig(
max_output_tokens=5000, temperature=0, top_k=10
),
)
# 4. Get token counts with null safety
usage = getattr(response, "usage_metadata", None)
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
total_tokens = getattr(usage, "total_token_count", 0) or (
prompt_tokens + completion_tokens
)
token_counts = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
token_tracker.add_usage(token_counts)
# 5. Return the response text
return response.text
async def embedding_func(texts: list[str]) -> np.ndarray:
return await siliconcloud_embedding(
texts,
model="BAAI/bge-m3",
api_key=siliconflow_api_key,
max_token_size=512,
)
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract=True,
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
# Reset tracker before processing queries
token_tracker.reset()
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,114 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import EmbeddingFunc
from lightrag.utils import TokenTracker
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
from dotenv import load_dotenv
load_dotenv()
token_tracker = TokenTracker()
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
"Qwen/Qwen2.5-7B-Instruct",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("SILICONFLOW_API_KEY"),
base_url="https://api.siliconflow.cn/v1/",
token_tracker=token_tracker,
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await siliconcloud_embedding(
texts,
model="BAAI/bge-m3",
api_key=os.getenv("SILICONFLOW_API_KEY"),
max_token_size=512,
)
# function test
async def test_funcs():
# Reset tracker before processing queries
token_tracker.reset()
result = await llm_model_func("How are you?")
print("llm_model_func: ", result)
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
asyncio.run(test_funcs())
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024, max_token_size=512, func=embedding_func
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
# Reset tracker before processing queries
token_tracker.reset()
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
if __name__ == "__main__":
main()

View File

@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
history_messages: list[dict[str, Any]] | None = None,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
**kwargs: Any,
) -> str:
if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
return content

View File

@@ -1038,7 +1038,7 @@ async def mix_kg_vector_query(
# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = c["content"]
chunk_text = "File path: " + c["file_path"] + "\n" + c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)
@@ -1334,9 +1334,9 @@ async def _get_node_data(
)
relations_context = list_of_list_to_csv(relations_section_list)
text_units_section_list = [["id", "content"]]
text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_section_list.append([i, t["content"], t["file_path"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context, relations_context, text_units_context
@@ -1597,9 +1597,9 @@ async def _get_edge_data(
)
entities_context = list_of_list_to_csv(entites_section_list)
text_units_section_list = [["id", "content"]]
text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_section_list.append([i, t["content"], t["file_path"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context, relations_context, text_units_context
@@ -1785,7 +1785,12 @@ async def naive_query(
f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
section = "\n--New Chunk--\n".join(
[
"File path: " + c["file_path"] + "\n" + c["content"]
for c in maybe_trun_chunks
]
)
if query_param.only_need_context:
return section

View File

@@ -222,7 +222,7 @@ When handling relationships with timestamps:
- Use markdown formatting with appropriate section headings
- Please respond in the same language as the user's question.
- Ensure the response maintains continuity with the conversation history.
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
- If you don't know the answer, just say so.
- Do not make anything up. Do not include information not provided by the Knowledge Base."""
@@ -320,7 +320,7 @@ When handling content with timestamps:
- Use markdown formatting with appropriate section headings
- Please respond in the same language as the user's question.
- Ensure the response maintains continuity with the conversation history.
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
- If you don't know the answer, just say so.
- Do not include information not provided by the Document Chunks."""
@@ -382,6 +382,6 @@ When handling information with timestamps:
- Ensure the response maintains continuity with the conversation history.
- Organize answer in sections focusing on one main point or aspect of the answer
- Use clear and descriptive section titles that reflect the content
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
- If you don't know the answer, just say so. Do not make anything up.
- Do not include information not provided by the Data Sources."""

View File

@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
class TokenTracker:
"""Track token usage for LLM calls."""
def __init__(self):
self.reset()
def reset(self):
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
self.call_count = 0
def add_usage(self, token_counts):
"""Add token usage from one LLM call.
Args:
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
"""
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
self.completion_tokens += token_counts.get("completion_tokens", 0)
# If total_tokens is provided, use it directly; otherwise calculate the sum
if "total_tokens" in token_counts:
self.total_tokens += token_counts["total_tokens"]
else:
self.total_tokens += token_counts.get(
"prompt_tokens", 0
) + token_counts.get("completion_tokens", 0)
self.call_count += 1
def get_usage(self):
"""Get current usage statistics."""
return {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.total_tokens,
"call_count": self.call_count,
}
def __str__(self):
usage = self.get_usage()
return (
f"LLM call count: {usage['call_count']}, "
f"Prompt tokens: {usage['prompt_tokens']}, "
f"Completion tokens: {usage['completion_tokens']}, "
f"Total tokens: {usage['total_tokens']}"
)