Merge pull request #1229 from choizhang/update-TokenTracker

feat(TokenTracker): Add context manager support to simplify token tracking
This commit is contained in:
zrguo
2025-03-31 11:20:51 +11:00
committed by GitHub
3 changed files with 31 additions and 30 deletions

View File

@@ -115,12 +115,11 @@ def main():
# Initialize RAG instance # Initialize RAG instance
rag = asyncio.run(initialize_rag()) rag = asyncio.run(initialize_rag())
# Reset tracker before processing queries
token_tracker.reset()
with open("./book.txt", "r", encoding="utf-8") as f: with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read()) rag.insert(f.read())
# Context Manager Method
with token_tracker:
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive") "What are the top themes in this story?", param=QueryParam(mode="naive")
@@ -135,19 +134,18 @@ def main():
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global") "What are the top themes in this story?",
param=QueryParam(mode="global"),
) )
) )
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid") "What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
) )
) )
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -44,15 +44,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
# function test
async def test_funcs():
    """Smoke-test llm_model_func while measuring its token usage.

    Uses the tracker's context-manager protocol: entering the `with`
    block resets the counters and leaving it reports the accumulated
    usage, replacing the old explicit reset()/get_usage() calls.
    """
    # Context Manager Method
    with token_tracker:
        result = await llm_model_func("How are you?")
        print("llm_model_func: ", result)


asyncio.run(test_funcs())

View File

@@ -962,6 +962,13 @@ class TokenTracker:
def __init__(self):
    # Delegate all counter initialization to reset() so construction
    # and in-place reuse share a single code path.
    self.reset()
def __enter__(self):
    # Entering the `with` block starts a fresh measurement window:
    # any previously accumulated counts are discarded.
    self.reset()
    # Return the tracker itself so `with token_tracker as t:` works.
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    # Report the usage accumulated inside the `with` block. Presumably
    # the class defines __str__/__repr__ elsewhere to format the
    # counters — TODO confirm; otherwise this prints the default repr.
    # Implicitly returns None, so exceptions raised in the block
    # propagate to the caller rather than being suppressed.
    print(self)
def reset(self): def reset(self):
self.prompt_tokens = 0 self.prompt_tokens = 0
self.completion_tokens = 0 self.completion_tokens = 0