From 8488229a29824aab37e93430724b2b6ac1e3646a Mon Sep 17 00:00:00 2001
From: choizhang
Date: Fri, 28 Mar 2025 01:25:15 +0800
Subject: [PATCH] feat: Add TokenTracker to track token usage for LLM calls

---
 examples/lightrag_gemini_track_token_demo.py  | 153 ++++++++++++++++++
 .../lightrag_siliconcloud_track_token_demo.py | 114 +++++++++++++
 lightrag/llm/openai.py                        |  10 ++
 lightrag/utils.py                             |  50 ++++++
 4 files changed, 327 insertions(+)
 create mode 100644 examples/lightrag_gemini_track_token_demo.py
 create mode 100644 examples/lightrag_siliconcloud_track_token_demo.py

diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py
new file mode 100644
index 00000000..e169a562
--- /dev/null
+++ b/examples/lightrag_gemini_track_token_demo.py
@@ -0,0 +1,153 @@
+# pip install -q -U google-genai to use gemini as a client
+
+import os
+import asyncio
+import numpy as np
+import nest_asyncio
+from google import genai
+from google.genai import types
+from dotenv import load_dotenv
+from lightrag.utils import EmbeddingFunc
+from lightrag import LightRAG, QueryParam
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from lightrag.llm.siliconcloud import siliconcloud_embedding
+from lightrag.utils import setup_logger
+from lightrag.utils import TokenTracker
+
+setup_logger("lightrag", level="DEBUG")
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+load_dotenv()
+gemini_api_key = os.getenv("GEMINI_API_KEY")
+siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+token_tracker = TokenTracker()
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    # 1. Initialize the GenAI Client with your Gemini API Key
+    client = genai.Client(api_key=gemini_api_key)
+
+    # 2. Combine prompts: system prompt, history, and user prompt
+    if history_messages is None:
+        history_messages = []
+
+    combined_prompt = ""
+    if system_prompt:
+        combined_prompt += f"{system_prompt}\n"
+
+    for msg in history_messages:
+        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
+        combined_prompt += f"{msg['role']}: {msg['content']}\n"
+
+    # Finally, add the new user prompt
+    combined_prompt += f"user: {prompt}"
+
+    # 3. Call the Gemini model
+    response = client.models.generate_content(
+        model="gemini-2.0-flash",
+        contents=[combined_prompt],
+        config=types.GenerateContentConfig(
+            max_output_tokens=5000, temperature=0, top_k=10
+        ),
+    )
+
+    # 4. Get token counts with null safety
+    usage = getattr(response, "usage_metadata", None)
+    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+    total_tokens = getattr(usage, "total_token_count", 0) or (
+        prompt_tokens + completion_tokens
+    )
+
+    token_counts = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    token_tracker.add_usage(token_counts)
+
+    # 5. Return the response text
+    return response.text
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        model="BAAI/bge-m3",
+        api_key=siliconflow_api_key,
+        max_token_size=512,
+    )
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        entity_extract_max_gleaning=1,
+        enable_llm_cache=True,
+        enable_llm_cache_for_entity_extract=True,
+        embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024,
+            max_token_size=8192,
+            func=embedding_func,
+        ),
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/lightrag_siliconcloud_track_token_demo.py b/examples/lightrag_siliconcloud_track_token_demo.py
new file mode 100644
index 00000000..fbbe94b4
--- /dev/null
+++ b/examples/lightrag_siliconcloud_track_token_demo.py
@@ -0,0 +1,114 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import openai_complete_if_cache
+from lightrag.llm.siliconcloud import siliconcloud_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.utils import TokenTracker
+import numpy as np
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from dotenv import load_dotenv
+
+load_dotenv()
+
+token_tracker = TokenTracker()
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        "Qwen/Qwen2.5-7B-Instruct",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        base_url="https://api.siliconflow.cn/v1/",
+        token_tracker=token_tracker,
+        **kwargs,
+    )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        model="BAAI/bge-m3",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        max_token_size=512,
+    )
+
+
+# function test
+async def test_funcs():
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    result = await llm_model_func("How are you?")
+    print("llm_model_func: ", result)
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+asyncio.run(test_funcs())
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024, max_token_size=512, func=embedding_func
+        ),
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 70aa0ceb..772e77d5 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
     history_messages: list[dict[str, Any]] | None = None,
     base_url: str | None = None,
     api_key: str | None = None,
+    token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
     if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
 
         if r"\u" in content:
             content = safe_unicode_decode(content.encode("utf-8"))
+
+        if token_tracker and hasattr(response, "usage"):
+            token_counts = {
+                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+                "total_tokens": getattr(response.usage, "total_tokens", 0),
+            }
+            token_tracker.add_usage(token_counts)
+
         return content
 
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 07f8c4b7..9732b8f3 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
             f"Storage implementation '{storage_name}' requires the following "
             f"environment variables: {', '.join(missing_vars)}"
         )
+
+
+class TokenTracker:
+    """Track token usage for LLM calls."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.total_tokens = 0
+        self.call_count = 0
+
+    def add_usage(self, token_counts):
+        """Add token usage from one LLM call.
+
+        Args:
+            token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
+        """
+        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
+        self.completion_tokens += token_counts.get("completion_tokens", 0)
+
+        # If total_tokens is provided, use it directly; otherwise calculate the sum
+        if "total_tokens" in token_counts:
+            self.total_tokens += token_counts["total_tokens"]
+        else:
+            self.total_tokens += token_counts.get(
+                "prompt_tokens", 0
+            ) + token_counts.get("completion_tokens", 0)
+
+        self.call_count += 1
+
+    def get_usage(self):
+        """Get current usage statistics."""
+        return {
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+            "call_count": self.call_count,
+        }
+
+    def __str__(self):
+        usage = self.get_usage()
+        return (
+            f"LLM call count: {usage['call_count']}, "
+            f"Prompt tokens: {usage['prompt_tokens']}, "
+            f"Completion tokens: {usage['completion_tokens']}, "
+            f"Total tokens: {usage['total_tokens']}"
+        )
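
A minimal usage sketch of the API this patch adds, illustrative only: a shared TokenTracker is passed to openai_complete_if_cache through the new token_tracker parameter, and cumulative counts are read back with get_usage(). It assumes SILICONFLOW_API_KEY is set in the environment, as in the demos above.

    import asyncio
    import os

    from lightrag.llm.openai import openai_complete_if_cache
    from lightrag.utils import TokenTracker

    token_tracker = TokenTracker()


    async def demo():
        # Each completion call records prompt/completion/total tokens on the tracker
        await openai_complete_if_cache(
            "Qwen/Qwen2.5-7B-Instruct",
            "How are you?",
            api_key=os.getenv("SILICONFLOW_API_KEY"),
            base_url="https://api.siliconflow.cn/v1/",
            token_tracker=token_tracker,
        )
        # Accumulated usage across all tracked calls, including the call count
        print("Token usage:", token_tracker.get_usage())
        print(token_tracker)  # __str__ gives a one-line summary


    asyncio.run(demo())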