cleaned code

This commit is contained in:
Yannick Stephan
2025-02-15 00:10:37 +01:00
parent 805da7b95b
commit 621540a54e
5 changed files with 26 additions and 14 deletions

View File

@@ -107,9 +107,11 @@ class BaseVectorStorage(StorageNameSpace):
raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None:
    """Delete a single entity by its name.

    Args:
        entity_name: Exact name of the entity to remove from the store.

    Raises:
        NotImplementedError: Always; concrete storage backends must override.
    """
    raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
    """Delete relations for a given entity by scanning metadata.

    Args:
        entity_name: Name of the entity whose relations should be removed.

    Raises:
        NotImplementedError: Always; concrete storage backends must override.
    """
    raise NotImplementedError

View File

@@ -524,7 +524,6 @@ class LightRAG:
embedding_func=None,
)
# What's this for? Is this necessary?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
@@ -1252,7 +1251,7 @@ class LightRAG:
"""
return await self.doc_status.get_status_counts()
async def adelete_by_doc_id(self, doc_id: str):
async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
Args:
@@ -1269,6 +1268,9 @@ class LightRAG:
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_id)
if not chunks:
return
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")

View File

@@ -66,7 +66,11 @@ class MultiModel:
return self._models[self._current_model]
async def llm_model_func(
self, prompt, system_prompt=None, history_messages=[], **kwargs
self,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] = [],
**kwargs: Any,
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("keyword_extraction", None)

View File

@@ -1608,7 +1608,7 @@ async def kg_query_with_keywords(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str:
) -> str | AsyncIterator[str]:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.

View File

@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
from typing import Any, Callable, Union, List, Optional
import xml.etree.ElementTree as ET
import bs4
@@ -67,7 +67,7 @@ class EmbeddingFunc:
@dataclass
class ReasoningResponse:
reasoning_content: str
reasoning_content: str | None
response_content: str
tag: str
@@ -109,7 +109,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None
def compute_args_hash(*args, cache_type: str = None) -> str:
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
@@ -220,11 +220,13 @@ def clean_str(input: Any) -> str:
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value: str) -> bool:
    """Return True if *value* is a plain decimal number (optional sign and dot).

    Scientific notation (e.g. "1e5") and empty strings are rejected.
    """
    decimal_pattern = r"^[-+]?[0-9]*\.?[0-9]+$"
    return re.match(decimal_pattern, value) is not None
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
def truncate_list_by_token_size(
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[int]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
@@ -334,7 +336,7 @@ def xml_to_json(xml_file):
return None
def process_combine_contexts(hl, ll):
def process_combine_contexts(hl: str, ll: str):
header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
@@ -640,7 +642,9 @@ def exists_func(obj, func_name: str) -> bool:
return False
def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
def get_conversation_turns(
conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
"""
Process conversation history to get the specified number of complete turns.
@@ -652,8 +656,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
Formatted string of the conversation history
"""
# Group messages into turns
turns = []
messages = []
turns: list[list[dict[str, Any]]] = []
messages: list[dict[str, Any]] = []
# First, filter out keyword extraction messages
for msg in conversation_history:
@@ -687,7 +691,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
turns = turns[-num_turns:]
# Format the turns into a string
formatted_turns = []
formatted_turns: list[str] = []
for turn in turns:
formatted_turns.extend(
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]