Merge branch 'main' into upload-error
This commit is contained in:
153
examples/lightrag_gemini_track_token_demo.py
Normal file
153
examples/lightrag_gemini_track_token_demo.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# pip install -q -U google-genai to use gemini as a client
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import nest_asyncio
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
from lightrag.llm.siliconcloud import siliconcloud_embedding
|
||||
from lightrag.utils import setup_logger
|
||||
from lightrag.utils import TokenTracker
|
||||
|
||||
setup_logger("lightrag", level="DEBUG")
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
# 1. Initialize the GenAI Client with your Gemini API Key
|
||||
client = genai.Client(api_key=gemini_api_key)
|
||||
|
||||
# 2. Combine prompts: system prompt, history, and user prompt
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
combined_prompt = ""
|
||||
if system_prompt:
|
||||
combined_prompt += f"{system_prompt}\n"
|
||||
|
||||
for msg in history_messages:
|
||||
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
||||
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
|
||||
# Finally, add the new user prompt
|
||||
combined_prompt += f"user: {prompt}"
|
||||
|
||||
# 3. Call the Gemini model
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=[combined_prompt],
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=5000, temperature=0, top_k=10
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Get token counts with null safety
|
||||
usage = getattr(response, "usage_metadata", None)
|
||||
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
||||
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
||||
total_tokens = getattr(usage, "total_token_count", 0) or (
|
||||
prompt_tokens + completion_tokens
|
||||
)
|
||||
|
||||
token_counts = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
# 5. Return the response text
|
||||
return response.text
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="BAAI/bge-m3",
|
||||
api_key=siliconflow_api_key,
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
entity_extract_max_gleaning=1,
|
||||
enable_llm_cache=True,
|
||||
enable_llm_cache_for_entity_extract=True,
|
||||
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
114
examples/lightrag_siliconcloud_track_token_demo.py
Normal file
114
examples/lightrag_siliconcloud_track_token_demo.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm.openai import openai_complete_if_cache
|
||||
from lightrag.llm.siliconcloud import siliconcloud_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.utils import TokenTracker
|
||||
import numpy as np
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"Qwen/Qwen2.5-7B-Instruct",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
base_url="https://api.siliconflow.cn/v1/",
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="BAAI/bge-m3",
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
# function test
|
||||
async def test_funcs():
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
result = await llm_model_func("How are you?")
|
||||
print("llm_model_func: ", result)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
asyncio.run(test_funcs())
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024, max_token_size=512, func=embedding_func
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if history_messages is None:
|
||||
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
|
||||
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
|
@@ -1038,7 +1038,7 @@ async def mix_kg_vector_query(
|
||||
# Include time information in content
|
||||
formatted_chunks = []
|
||||
for c in maybe_trun_chunks:
|
||||
chunk_text = c["content"]
|
||||
chunk_text = "File path: " + c["file_path"] + "\n" + c["content"]
|
||||
if c["created_at"]:
|
||||
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
|
||||
formatted_chunks.append(chunk_text)
|
||||
@@ -1334,9 +1334,9 @@ async def _get_node_data(
|
||||
)
|
||||
relations_context = list_of_list_to_csv(relations_section_list)
|
||||
|
||||
text_units_section_list = [["id", "content"]]
|
||||
text_units_section_list = [["id", "content", "file_path"]]
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_section_list.append([i, t["content"], t["file_path"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return entities_context, relations_context, text_units_context
|
||||
|
||||
@@ -1597,9 +1597,9 @@ async def _get_edge_data(
|
||||
)
|
||||
entities_context = list_of_list_to_csv(entites_section_list)
|
||||
|
||||
text_units_section_list = [["id", "content"]]
|
||||
text_units_section_list = [["id", "content", "file_path"]]
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"]])
|
||||
text_units_section_list.append([i, t["content"], t["file_path"]])
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return entities_context, relations_context, text_units_context
|
||||
|
||||
@@ -1785,7 +1785,12 @@ async def naive_query(
|
||||
f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
|
||||
)
|
||||
|
||||
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
|
||||
section = "\n--New Chunk--\n".join(
|
||||
[
|
||||
"File path: " + c["file_path"] + "\n" + c["content"]
|
||||
for c in maybe_trun_chunks
|
||||
]
|
||||
)
|
||||
|
||||
if query_param.only_need_context:
|
||||
return section
|
||||
|
@@ -222,7 +222,7 @@ When handling relationships with timestamps:
|
||||
- Use markdown formatting with appropriate section headings
|
||||
- Please respond in the same language as the user's question.
|
||||
- Ensure the response maintains continuity with the conversation history.
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
|
||||
- If you don't know the answer, just say so.
|
||||
- Do not make anything up. Do not include information not provided by the Knowledge Base."""
|
||||
|
||||
@@ -320,7 +320,7 @@ When handling content with timestamps:
|
||||
- Use markdown formatting with appropriate section headings
|
||||
- Please respond in the same language as the user's question.
|
||||
- Ensure the response maintains continuity with the conversation history.
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
|
||||
- If you don't know the answer, just say so.
|
||||
- Do not include information not provided by the Document Chunks."""
|
||||
|
||||
@@ -382,6 +382,6 @@ When handling information with timestamps:
|
||||
- Ensure the response maintains continuity with the conversation history.
|
||||
- Organize answer in sections focusing on one main point or aspect of the answer
|
||||
- Use clear and descriptive section titles that reflect the content
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
|
||||
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
|
||||
- If you don't know the answer, just say so. Do not make anything up.
|
||||
- Do not include information not provided by the Data Sources."""
|
||||
|
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
|
||||
class TokenTracker:
|
||||
"""Track token usage for LLM calls."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.call_count = 0
|
||||
|
||||
def add_usage(self, token_counts):
|
||||
"""Add token usage from one LLM call.
|
||||
|
||||
Args:
|
||||
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
|
||||
"""
|
||||
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
|
||||
self.completion_tokens += token_counts.get("completion_tokens", 0)
|
||||
|
||||
# If total_tokens is provided, use it directly; otherwise calculate the sum
|
||||
if "total_tokens" in token_counts:
|
||||
self.total_tokens += token_counts["total_tokens"]
|
||||
else:
|
||||
self.total_tokens += token_counts.get(
|
||||
"prompt_tokens", 0
|
||||
) + token_counts.get("completion_tokens", 0)
|
||||
|
||||
self.call_count += 1
|
||||
|
||||
def get_usage(self):
|
||||
"""Get current usage statistics."""
|
||||
return {
|
||||
"prompt_tokens": self.prompt_tokens,
|
||||
"completion_tokens": self.completion_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"call_count": self.call_count,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
usage = self.get_usage()
|
||||
return (
|
||||
f"LLM call count: {usage['call_count']}, "
|
||||
f"Prompt tokens: {usage['prompt_tokens']}, "
|
||||
f"Completion tokens: {usage['completion_tokens']}, "
|
||||
f"Total tokens: {usage['total_tokens']}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user