Merge pull request #1207 from choizhang/track-tokens
feat: Add TokenTracker to track token usage for LLM calls
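In brief, the usage pattern this change enables — a minimal sketch using only the API added below (the tracked LLM calls themselves are elided here; the two demo files that follow are complete programs):

from lightrag.utils import TokenTracker

token_tracker = TokenTracker()
token_tracker.reset()  # start a fresh count before a batch of work

# ... run LLM calls that report usage to the tracker, either automatically
# (openai_complete_if_cache(..., token_tracker=token_tracker)) or manually
# (token_tracker.add_usage({...}), as in the Gemini demo below) ...

print("Token usage:", token_tracker.get_usage())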
examples/lightrag_gemini_track_token_demo.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# pip install -q -U google-genai to use gemini as a client

import os
import asyncio
import numpy as np
import nest_asyncio
from google import genai
from google.genai import types
from dotenv import load_dotenv
from lightrag.utils import EmbeddingFunc
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import setup_logger
from lightrag.utils import TokenTracker

setup_logger("lightrag", level="DEBUG")

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

token_tracker = TokenTracker()


async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    # 1. Initialize the GenAI Client with your Gemini API Key
    client = genai.Client(api_key=gemini_api_key)

    # 2. Combine prompts: system prompt, history, and user prompt
    if history_messages is None:
        history_messages = []

    combined_prompt = ""
    if system_prompt:
        combined_prompt += f"{system_prompt}\n"

    for msg in history_messages:
        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
        combined_prompt += f"{msg['role']}: {msg['content']}\n"

    # Finally, add the new user prompt
    combined_prompt += f"user: {prompt}"

    # 3. Call the Gemini model
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[combined_prompt],
        config=types.GenerateContentConfig(
            max_output_tokens=5000, temperature=0, top_k=10
        ),
    )

    # 4. Get token counts with null safety
    usage = getattr(response, "usage_metadata", None)
    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
    total_tokens = getattr(usage, "total_token_count", 0) or (
        prompt_tokens + completion_tokens
    )

    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    token_tracker.add_usage(token_counts)

    # 5. Return the response text
    return response.text


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        model="BAAI/bge-m3",
        api_key=siliconflow_api_key,
        max_token_size=512,
    )


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        entity_extract_max_gleaning=1,
        enable_llm_cache=True,
        enable_llm_cache_for_entity_extract=True,
        embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=embedding_func,
        ),
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Reset tracker before processing queries
    token_tracker.reset()

    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="naive")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="local")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
        )
    )

    # Display final token usage after all queries
    print("Token usage:", token_tracker.get_usage())


if __name__ == "__main__":
    main()
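The Gemini client is not wired into TokenTracker, so the demo above extracts usage_metadata by hand (step 4). If that pattern is reused across call sites, it could be factored into a small helper; a sketch under that assumption (gemini_token_counts is a hypothetical name, not part of this PR):

def gemini_token_counts(response) -> dict:
    # Hypothetical helper: mirrors step 4 of the demo above.
    usage = getattr(response, "usage_metadata", None)
    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
    total_tokens = getattr(usage, "total_token_count", 0) or (
        prompt_tokens + completion_tokens
    )
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }

# Usage inside llm_model_func:
# token_tracker.add_usage(gemini_token_counts(response))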
examples/lightrag_siliconcloud_track_token_demo.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import EmbeddingFunc
from lightrag.utils import TokenTracker
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
from dotenv import load_dotenv

load_dotenv()

token_tracker = TokenTracker()
WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        "Qwen/Qwen2.5-7B-Instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        base_url="https://api.siliconflow.cn/v1/",
        token_tracker=token_tracker,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        model="BAAI/bge-m3",
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        max_token_size=512,
    )


# function test (runs at import time, before main)
async def test_funcs():
    # Reset tracker before the test query
    token_tracker.reset()

    result = await llm_model_func("How are you?")
    print("llm_model_func: ", result)

    # Display token usage for the test query
    print("Token usage:", token_tracker.get_usage())


asyncio.run(test_funcs())


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=1024, max_token_size=512, func=embedding_func
        ),
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Reset tracker before processing queries
    token_tracker.reset()

    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="naive")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="local")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )

    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
        )
    )

    # Display final token usage after all queries
    print("Token usage:", token_tracker.get_usage())


if __name__ == "__main__":
    main()
lightrag/llm/openai.py
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
     history_messages: list[dict[str, Any]] | None = None,
     base_url: str | None = None,
     api_key: str | None = None,
+    token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
     if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
 
     if r"\u" in content:
         content = safe_unicode_decode(content.encode("utf-8"))
 
+    if token_tracker and hasattr(response, "usage"):
+        token_counts = {
+            "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+            "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+            "total_tokens": getattr(response.usage, "total_tokens", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
     return content
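Note that openai_complete_if_cache only ever calls token_tracker.add_usage(...) with a plain dict, so the parameter (typed Any) is effectively duck-typed: any object with a compatible add_usage method could stand in for TokenTracker. A sketch under that assumption (UsageLogger is hypothetical, not part of this PR):

class UsageLogger:
    # Hypothetical stand-in for TokenTracker: logs each call instead of accumulating.
    def add_usage(self, token_counts: dict) -> None:
        print(
            f"prompt={token_counts.get('prompt_tokens', 0)} "
            f"completion={token_counts.get('completion_tokens', 0)} "
            f"total={token_counts.get('total_tokens', 0)}"
        )

# await openai_complete_if_cache(..., token_tracker=UsageLogger(), ...)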
lightrag/utils.py
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
             f"Storage implementation '{storage_name}' requires the following "
             f"environment variables: {', '.join(missing_vars)}"
         )
+
+
+class TokenTracker:
+    """Track token usage for LLM calls."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.total_tokens = 0
+        self.call_count = 0
+
+    def add_usage(self, token_counts):
+        """Add token usage from one LLM call.
+
+        Args:
+            token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
+        """
+        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
+        self.completion_tokens += token_counts.get("completion_tokens", 0)
+
+        # If total_tokens is provided, use it directly; otherwise calculate the sum
+        if "total_tokens" in token_counts:
+            self.total_tokens += token_counts["total_tokens"]
+        else:
+            self.total_tokens += token_counts.get(
+                "prompt_tokens", 0
+            ) + token_counts.get("completion_tokens", 0)
+
+        self.call_count += 1
+
+    def get_usage(self):
+        """Get current usage statistics."""
+        return {
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+            "call_count": self.call_count,
+        }
+
+    def __str__(self):
+        usage = self.get_usage()
+        return (
+            f"LLM call count: {usage['call_count']}, "
+            f"Prompt tokens: {usage['prompt_tokens']}, "
+            f"Completion tokens: {usage['completion_tokens']}, "
+            f"Total tokens: {usage['total_tokens']}"
+        )
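A quick worked example of the accumulator semantics, using only the class above (the second dict omits total_tokens, exercising the fallback sum):

from lightrag.utils import TokenTracker

tracker = TokenTracker()
tracker.add_usage({"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150})
tracker.add_usage({"prompt_tokens": 20, "completion_tokens": 10})  # total inferred: 30
print(tracker.get_usage())
# {'prompt_tokens': 120, 'completion_tokens': 60, 'total_tokens': 180, 'call_count': 2}
print(tracker)
# LLM call count: 2, Prompt tokens: 120, Completion tokens: 60, Total tokens: 180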