diff --git a/examples/lightrag_gemini_track_token_demo.py b/examples/lightrag_gemini_track_token_demo.py index e169a562..a72fc717 100644 --- a/examples/lightrag_gemini_track_token_demo.py +++ b/examples/lightrag_gemini_track_token_demo.py @@ -115,38 +115,36 @@ def main(): # Initialize RAG instance rag = asyncio.run(initialize_rag()) - # Reset tracker before processing queries - token_tracker.reset() - with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="naive") + # Context Manager Method + with token_tracker: + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) ) - ) - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="local") + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) ) - ) - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="global") + print( + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="global"), + ) ) - ) - print( - rag.query( - "What are the top themes in this story?", param=QueryParam(mode="hybrid") + print( + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid"), + ) ) - ) - - # Display final token usage after main query - print("Token usage:", token_tracker.get_usage()) if __name__ == "__main__": diff --git a/examples/lightrag_siliconcloud_track_token_demo.py b/examples/lightrag_siliconcloud_track_token_demo.py index fbbe94b4..d82a30bc 100644 --- a/examples/lightrag_siliconcloud_track_token_demo.py +++ b/examples/lightrag_siliconcloud_track_token_demo.py @@ -44,14 +44,10 @@ async def embedding_func(texts: list[str]) -> np.ndarray: # function test async def test_funcs(): - # Reset tracker before processing queries - token_tracker.reset() - - result = await llm_model_func("How are you?") - print("llm_model_func: ", result) - - # Display final token usage after main query - print("Token usage:", token_tracker.get_usage()) + # Context Manager Method + with token_tracker: + result = await llm_model_func("How are you?") + print("llm_model_func: ", result) asyncio.run(test_funcs()) diff --git a/lightrag/utils.py b/lightrag/utils.py index 44a85425..4515e080 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -962,6 +962,13 @@ class TokenTracker: def __init__(self): self.reset() + def __enter__(self): + self.reset() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + print(self) + def reset(self): self.prompt_tokens = 0 self.completion_tokens = 0