Merge pull request #1207 from choizhang/track-tokens

feat: Add TokenTracker to track token usage for LLM calls
zrguo authored on 2025-03-28 16:39:55 +11:00 (committed by GitHub)
4 changed files with 327 additions and 0 deletions


@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
    history_messages: list[dict[str, Any]] | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    **kwargs: Any,
) -> str:
    if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
return content
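
A caller-side sketch of how the new parameter is meant to be used (not part of this diff; the module paths lightrag.llm.openai and lightrag.utils, the model name, and the positional model/prompt arguments are assumptions based on the hunk headers above):

import asyncio

from lightrag.llm.openai import openai_complete_if_cache  # assumed module path
from lightrag.utils import TokenTracker                    # assumed module path


async def main() -> None:
    tracker = TokenTracker()

    # token_tracker is the keyword argument added in this PR; when the OpenAI
    # response carries a `usage` object, its counts are accumulated on `tracker`.
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",                     # hypothetical model name
        "Summarize RAG in one sentence.",  # prompt
        token_tracker=tracker,
    )
    print(answer)
    print(tracker)  # e.g. "LLM call count: 1, Prompt tokens: ..., Total tokens: ..."


asyncio.run(main())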


@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
class TokenTracker:
"""Track token usage for LLM calls."""
def __init__(self):
self.reset()
def reset(self):
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
self.call_count = 0
def add_usage(self, token_counts):
"""Add token usage from one LLM call.
Args:
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
"""
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
self.completion_tokens += token_counts.get("completion_tokens", 0)
# If total_tokens is provided, use it directly; otherwise calculate the sum
if "total_tokens" in token_counts:
self.total_tokens += token_counts["total_tokens"]
else:
self.total_tokens += token_counts.get(
"prompt_tokens", 0
) + token_counts.get("completion_tokens", 0)
self.call_count += 1
def get_usage(self):
"""Get current usage statistics."""
return {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.total_tokens,
"call_count": self.call_count,
}
def __str__(self):
usage = self.get_usage()
return (
f"LLM call count: {usage['call_count']}, "
f"Prompt tokens: {usage['prompt_tokens']}, "
f"Completion tokens: {usage['completion_tokens']}, "
f"Total tokens: {usage['total_tokens']}"
)
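
A minimal standalone sketch of the TokenTracker API defined above; the token counts are illustrative values only:

tracker = TokenTracker()

# Record usage from two hypothetical LLM calls.
tracker.add_usage({"prompt_tokens": 120, "completion_tokens": 35, "total_tokens": 155})
tracker.add_usage({"prompt_tokens": 80, "completion_tokens": 20})  # total_tokens derived as 100

print(tracker.get_usage())
# {'prompt_tokens': 200, 'completion_tokens': 55, 'total_tokens': 255, 'call_count': 2}

print(tracker)
# LLM call count: 2, Prompt tokens: 200, Completion tokens: 55, Total tokens: 255

tracker.reset()  # clear all counters before a new batch of calls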