feat: Add TokenTracker to track token usage for LLM calls
This commit is contained in:
153
examples/lightrag_gemini_track_token_demo.py
Normal file
153
examples/lightrag_gemini_track_token_demo.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# pip install -q -U google-genai to use gemini as a client
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import nest_asyncio
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
from lightrag.llm.siliconcloud import siliconcloud_embedding
|
||||
from lightrag.utils import setup_logger
|
||||
from lightrag.utils import TokenTracker
|
||||
|
||||
setup_logger("lightrag", level="DEBUG")
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
# 1. Initialize the GenAI Client with your Gemini API Key
|
||||
client = genai.Client(api_key=gemini_api_key)
|
||||
|
||||
# 2. Combine prompts: system prompt, history, and user prompt
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
combined_prompt = ""
|
||||
if system_prompt:
|
||||
combined_prompt += f"{system_prompt}\n"
|
||||
|
||||
for msg in history_messages:
|
||||
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
||||
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
|
||||
# Finally, add the new user prompt
|
||||
combined_prompt += f"user: {prompt}"
|
||||
|
||||
# 3. Call the Gemini model
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=[combined_prompt],
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=5000, temperature=0, top_k=10
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Get token counts with null safety
|
||||
usage = getattr(response, "usage_metadata", None)
|
||||
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
||||
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
||||
total_tokens = getattr(usage, "total_token_count", 0) or (
|
||||
prompt_tokens + completion_tokens
|
||||
)
|
||||
|
||||
token_counts = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
# 5. Return the response text
|
||||
return response.text
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="BAAI/bge-m3",
|
||||
api_key=siliconflow_api_key,
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
entity_extract_max_gleaning=1,
|
||||
enable_llm_cache=True,
|
||||
enable_llm_cache_for_entity_extract=True,
|
||||
embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
114
examples/lightrag_siliconcloud_track_token_demo.py
Normal file
114
examples/lightrag_siliconcloud_track_token_demo.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm.openai import openai_complete_if_cache
|
||||
from lightrag.llm.siliconcloud import siliconcloud_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.utils import TokenTracker
|
||||
import numpy as np
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
token_tracker = TokenTracker()
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"Qwen/Qwen2.5-7B-Instruct",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
base_url="https://api.siliconflow.cn/v1/",
|
||||
token_tracker=token_tracker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="BAAI/bge-m3",
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
# function test
|
||||
async def test_funcs():
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
result = await llm_model_func("How are you?")
|
||||
print("llm_model_func: ", result)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
asyncio.run(test_funcs())
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1024, max_token_size=512, func=embedding_func
|
||||
),
|
||||
)
|
||||
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize RAG instance
|
||||
rag = asyncio.run(initialize_rag())
|
||||
|
||||
# Reset tracker before processing queries
|
||||
token_tracker.reset()
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
|
||||
# Display final token usage after main query
|
||||
print("Token usage:", token_tracker.get_usage())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if history_messages is None:
|
||||
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
|
||||
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
|
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
|
||||
class TokenTracker:
|
||||
"""Track token usage for LLM calls."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.call_count = 0
|
||||
|
||||
def add_usage(self, token_counts):
|
||||
"""Add token usage from one LLM call.
|
||||
|
||||
Args:
|
||||
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
|
||||
"""
|
||||
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
|
||||
self.completion_tokens += token_counts.get("completion_tokens", 0)
|
||||
|
||||
# If total_tokens is provided, use it directly; otherwise calculate the sum
|
||||
if "total_tokens" in token_counts:
|
||||
self.total_tokens += token_counts["total_tokens"]
|
||||
else:
|
||||
self.total_tokens += token_counts.get(
|
||||
"prompt_tokens", 0
|
||||
) + token_counts.get("completion_tokens", 0)
|
||||
|
||||
self.call_count += 1
|
||||
|
||||
def get_usage(self):
|
||||
"""Get current usage statistics."""
|
||||
return {
|
||||
"prompt_tokens": self.prompt_tokens,
|
||||
"completion_tokens": self.completion_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"call_count": self.call_count,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
usage = self.get_usage()
|
||||
return (
|
||||
f"LLM call count: {usage['call_count']}, "
|
||||
f"Prompt tokens: {usage['prompt_tokens']}, "
|
||||
f"Completion tokens: {usage['completion_tokens']}, "
|
||||
f"Total tokens: {usage['total_tokens']}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user