From 8488229a29824aab37e93430724b2b6ac1e3646a Mon Sep 17 00:00:00 2001 From: choizhang Date: Fri, 28 Mar 2025 01:25:15 +0800 Subject: [PATCH 1/2] feat: Add TokenTracker to track token usage for LLM calls --- examples/lightrag_gemini_track_token_demo.py | 153 ++++++++++++++++++ .../lightrag_siliconcloud_track_token_demo.py | 114 +++++++++++++ lightrag/llm/openai.py | 10 ++ lightrag/utils.py | 50 ++++++ 4 files changed, 327 insertions(+) create mode 100644 examples/lightrag_gemini_track_token_demo.py create mode 100644 examples/lightrag_siliconcloud_track_token_demo.py diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py new file mode 100644 index 00000000..e169a562 --- /dev/null +++ b/examples/lightrag_gemini_track_token_demo.py @@ -0,0 +1,153 @@ +# pip install -q -U google-genai to use gemini as a client + +import os +import asyncio +import numpy as np +import nest_asyncio +from google import genai +from google.genai import types +from dotenv import load_dotenv +from lightrag.utils import EmbeddingFunc +from lightrag import LightRAG, QueryParam +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.llm.siliconcloud import siliconcloud_embedding +from lightrag.utils import setup_logger +from lightrag.utils import TokenTracker + +setup_logger("lightrag", level="DEBUG") + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +load_dotenv() +gemini_api_key = os.getenv("GEMINI_API_KEY") +siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY") + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +token_tracker = TokenTracker() + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + # 1. Initialize the GenAI Client with your Gemini API Key + client = genai.Client(api_key=gemini_api_key) + + # 2. Combine prompts: system prompt, history, and user prompt + if history_messages is None: + history_messages = [] + + combined_prompt = "" + if system_prompt: + combined_prompt += f"{system_prompt}\n" + + for msg in history_messages: + # Each msg is expected to be a dict: {"role": "...", "content": "..."} + combined_prompt += f"{msg['role']}: {msg['content']}\n" + + # Finally, add the new user prompt + combined_prompt += f"user: {prompt}" + + # 3. Call the Gemini model + response = client.models.generate_content( + model="gemini-2.0-flash", + contents=[combined_prompt], + config=types.GenerateContentConfig( + max_output_tokens=5000, temperature=0, top_k=10 + ), + ) + + # 4. Get token counts with null safety + usage = getattr(response, "usage_metadata", None) + prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0 + completion_tokens = getattr(usage, "candidates_token_count", 0) or 0 + total_tokens = getattr(usage, "total_token_count", 0) or ( + prompt_tokens + completion_tokens + ) + + token_counts = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + token_tracker.add_usage(token_counts) + + # 5. Return the response text + return response.text + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await siliconcloud_embedding( + texts, + model="BAAI/bge-m3", + api_key=siliconflow_api_key, + max_token_size=512, + ) + + +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + entity_extract_max_gleaning=1, + enable_llm_cache=True, + enable_llm_cache_for_entity_extract=True, + embedding_cache_config={"enabled": True, "similarity_threshold": 0.90}, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=embedding_func, + ), + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag + + +def main(): + # Initialize RAG instance + rag = asyncio.run(initialize_rag()) + + # Reset tracker before processing queries + token_tracker.reset() + + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) + + # Display final token usage after main query + print("Token usage:", token_tracker.get_usage()) + + +if __name__ == "__main__": + main() diff --git a/examples/lightrag_siliconcloud_track_token_demo.py b/examples/lightrag_siliconcloud_track_token_demo.py new file mode 100644 index 00000000..fbbe94b4 --- /dev/null +++ b/examples/lightrag_siliconcloud_track_token_demo.py @@ -0,0 +1,114 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import openai_complete_if_cache +from lightrag.llm.siliconcloud import siliconcloud_embedding +from lightrag.utils import EmbeddingFunc +from lightrag.utils import TokenTracker +import numpy as np +from lightrag.kg.shared_storage import initialize_pipeline_status +from dotenv import load_dotenv + +load_dotenv() + +token_tracker = TokenTracker() +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await openai_complete_if_cache( + "Qwen/Qwen2.5-7B-Instruct", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("SILICONFLOW_API_KEY"), + base_url="https://api.siliconflow.cn/v1/", + token_tracker=token_tracker, + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await siliconcloud_embedding( + texts, + model="BAAI/bge-m3", + api_key=os.getenv("SILICONFLOW_API_KEY"), + max_token_size=512, + ) + + +# function test +async def test_funcs(): + # Reset tracker before processing queries + token_tracker.reset() + + result = await llm_model_func("How are you?") + print("llm_model_func: ", result) + + # Display final token usage after main query + print("Token usage:", token_tracker.get_usage()) + + +asyncio.run(test_funcs()) + + +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, max_token_size=512, func=embedding_func + ), + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag + + +def main(): + # Initialize RAG instance + rag = asyncio.run(initialize_rag()) + + # Reset tracker before processing queries + token_tracker.reset() + + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) + + # Display final token usage after main query + print("Token usage:", token_tracker.get_usage()) + + +if __name__ == "__main__": + main() diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 70aa0ceb..772e77d5 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -58,6 +58,7 @@ async def openai_complete_if_cache( history_messages: list[dict[str, Any]] | None = None, base_url: str | None = None, api_key: str | None = None, + token_tracker: Any | None = None, **kwargs: Any, ) -> str: if history_messages is None: @@ -154,6 +155,15 @@ async def openai_complete_if_cache( if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "completion_tokens": getattr(response.usage, "completion_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + return content diff --git a/lightrag/utils.py b/lightrag/utils.py index 07f8c4b7..9732b8f3 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None: f"Storage implementation '{storage_name}' requires the following " f"environment variables: {', '.join(missing_vars)}" ) + + +class TokenTracker: + """Track token usage for LLM calls.""" + + def __init__(self): + self.reset() + + def reset(self): + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_tokens = 0 + self.call_count = 0 + + def add_usage(self, token_counts): + """Add token usage from one LLM call. + + Args: + token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens + """ + self.prompt_tokens += token_counts.get("prompt_tokens", 0) + self.completion_tokens += token_counts.get("completion_tokens", 0) + + # If total_tokens is provided, use it directly; otherwise calculate the sum + if "total_tokens" in token_counts: + self.total_tokens += token_counts["total_tokens"] + else: + self.total_tokens += token_counts.get( + "prompt_tokens", 0 + ) + token_counts.get("completion_tokens", 0) + + self.call_count += 1 + + def get_usage(self): + """Get current usage statistics.""" + return { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "call_count": self.call_count, + } + + def __str__(self): + usage = self.get_usage() + return ( + f"LLM call count: {usage['call_count']}, " + f"Prompt tokens: {usage['prompt_tokens']}, " + f"Completion tokens: {usage['completion_tokens']}, " + f"Total tokens: {usage['total_tokens']}" + ) From 87fbffde14a93f2ca39f2a3f4affbac120a7722d Mon Sep 17 00:00:00 2001 From: zrguo Date: Fri, 28 Mar 2025 13:30:24 +0800 Subject: [PATCH 2/2] fix citation --- lightrag/operate.py | 17 +++++++++++------ lightrag/prompt.py | 6 +++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 9f5eb92b..816227f5 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1038,7 +1038,7 @@ async def mix_kg_vector_query( # Include time information in content formatted_chunks = [] for c in maybe_trun_chunks: - chunk_text = c["content"] + chunk_text = "File path: " + c["file_path"] + "\n" + c["content"] if c["created_at"]: chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}" formatted_chunks.append(chunk_text) @@ -1334,9 +1334,9 @@ async def _get_node_data( ) relations_context = list_of_list_to_csv(relations_section_list) - text_units_section_list = [["id", "content"]] + text_units_section_list = [["id", "content", "file_path"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"]]) + text_units_section_list.append([i, t["content"], t["file_path"]]) text_units_context = list_of_list_to_csv(text_units_section_list) return entities_context, relations_context, text_units_context @@ -1597,9 +1597,9 @@ async def _get_edge_data( ) entities_context = list_of_list_to_csv(entites_section_list) - text_units_section_list = [["id", "content"]] + text_units_section_list = [["id", "content", "file_path"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"]]) + text_units_section_list.append([i, t["content"], t["file_path"]]) text_units_context = list_of_list_to_csv(text_units_section_list) return entities_context, relations_context, text_units_context @@ -1785,7 +1785,12 @@ async def naive_query( f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" ) - section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + section = "\n--New Chunk--\n".join( + [ + "File path: " + c["file_path"] + "\n" + c["content"] + for c in maybe_trun_chunks + ] + ) if query_param.only_need_context: return section diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 88ebd7fc..d6d46e1f 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -222,7 +222,7 @@ When handling relationships with timestamps: - Use markdown formatting with appropriate section headings - Please respond in the same language as the user's question. - Ensure the response maintains continuity with the conversation history. -- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path - If you don't know the answer, just say so. - Do not make anything up. Do not include information not provided by the Knowledge Base.""" @@ -320,7 +320,7 @@ When handling content with timestamps: - Use markdown formatting with appropriate section headings - Please respond in the same language as the user's question. - Ensure the response maintains continuity with the conversation history. -- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path - If you don't know the answer, just say so. - Do not include information not provided by the Document Chunks.""" @@ -382,6 +382,6 @@ When handling information with timestamps: - Ensure the response maintains continuity with the conversation history. - Organize answer in sections focusing on one main point or aspect of the answer - Use clear and descriptive section titles that reflect the content -- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path - If you don't know the answer, just say so. Do not make anything up. - Do not include information not provided by the Data Sources."""