cleaned code

This commit is contained in:
Yannick Stephan
2025-02-15 00:10:37 +01:00
parent 805da7b95b
commit 621540a54e
5 changed files with 26 additions and 14 deletions

View File

@@ -107,9 +107,11 @@ class BaseVectorStorage(StorageNameSpace):
raise NotImplementedError raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError

View File

@@ -524,7 +524,6 @@ class LightRAG:
embedding_func=None, embedding_func=None,
) )
# What is this for? Is this necessary?
if self.llm_response_cache and hasattr( if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config" self.llm_response_cache, "global_config"
): ):
@@ -1252,7 +1251,7 @@ class LightRAG:
""" """
return await self.doc_status.get_status_counts() return await self.doc_status.get_status_counts()
async def adelete_by_doc_id(self, doc_id: str): async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data """Delete a document and all its related data
Args: Args:
@@ -1269,6 +1268,9 @@ class LightRAG:
# 2. Get all related chunks # 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_id) chunks = await self.text_chunks.get_by_id(doc_id)
if not chunks:
return
chunk_ids = list(chunks.keys()) chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete") logger.debug(f"Found {len(chunk_ids)} chunks to delete")

View File

@@ -66,7 +66,11 @@ class MultiModel:
return self._models[self._current_model] return self._models[self._current_model]
async def llm_model_func( async def llm_model_func(
self, prompt, system_prompt=None, history_messages=[], **kwargs self,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] = [],
**kwargs: Any,
) -> str: ) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)

View File

@@ -1608,7 +1608,7 @@ async def kg_query_with_keywords(
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str], global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
) -> str: ) -> str | AsyncIterator[str]:
""" """
Refactored kg_query that does NOT extract keywords by itself. Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.

View File

@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union, List, Optional from typing import Any, Callable, Union, List, Optional
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import bs4 import bs4
@@ -67,7 +67,7 @@ class EmbeddingFunc:
@dataclass @dataclass
class ReasoningResponse: class ReasoningResponse:
reasoning_content: str reasoning_content: str | None
response_content: str response_content: str
tag: str tag: str
@@ -109,7 +109,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None raise e from None
def compute_args_hash(*args, cache_type: str = None) -> str: def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments. """Compute a hash for the given arguments.
Args: Args:
*args: Arguments to hash *args: Arguments to hash
@@ -220,11 +220,13 @@ def clean_str(input: Any) -> str:
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value): def is_float_regex(value: str) -> bool:
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): def truncate_list_by_token_size(
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[int]:
"""Truncate a list of data by token size""" """Truncate a list of data by token size"""
if max_token_size <= 0: if max_token_size <= 0:
return [] return []
@@ -334,7 +336,7 @@ def xml_to_json(xml_file):
return None return None
def process_combine_contexts(hl, ll): def process_combine_contexts(hl: str, ll: str):
header = None header = None
list_hl = csv_string_to_list(hl.strip()) list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip()) list_ll = csv_string_to_list(ll.strip())
@@ -640,7 +642,9 @@ def exists_func(obj, func_name: str) -> bool:
return False return False
def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str: def get_conversation_turns(
conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
""" """
Process conversation history to get the specified number of complete turns. Process conversation history to get the specified number of complete turns.
@@ -652,8 +656,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
Formatted string of the conversation history Formatted string of the conversation history
""" """
# Group messages into turns # Group messages into turns
turns = [] turns: list[list[dict[str, Any]]] = []
messages = [] messages: list[dict[str, Any]] = []
# First, filter out keyword extraction messages # First, filter out keyword extraction messages
for msg in conversation_history: for msg in conversation_history:
@@ -687,7 +691,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
turns = turns[-num_turns:] turns = turns[-num_turns:]
# Format the turns into a string # Format the turns into a string
formatted_turns = [] formatted_turns: list[str] = []
for turn in turns: for turn in turns:
formatted_turns.extend( formatted_turns.extend(
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"] [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]