From 621540a54e9dd043b2dab0a366cf214641fb5dd5 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:10:37 +0100 Subject: [PATCH] cleaned code --- lightrag/base.py | 2 ++ lightrag/lightrag.py | 6 ++++-- lightrag/llm.py | 6 +++++- lightrag/operate.py | 2 +- lightrag/utils.py | 24 ++++++++++++++---------- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 29335494..42f6d1e9 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -107,9 +107,11 @@ class BaseVectorStorage(StorageNameSpace): raise NotImplementedError async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ce86e938..af241f65 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -524,7 +524,6 @@ class LightRAG: embedding_func=None, ) - # What's for, Is this nessisary ? if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" ): @@ -1252,7 +1251,7 @@ class LightRAG: """ return await self.doc_status.get_status_counts() - async def adelete_by_doc_id(self, doc_id: str): + async def adelete_by_doc_id(self, doc_id: str) -> None: """Delete a document and all its related data Args: @@ -1269,6 +1268,9 @@ class LightRAG: # 2. 
Get all related chunks chunks = await self.text_chunks.get_by_id(doc_id) + if not chunks: + return + chunk_ids = list(chunks.keys()) logger.debug(f"Found {len(chunk_ids)} chunks to delete") diff --git a/lightrag/llm.py b/lightrag/llm.py index 3ca17725..b4baef68 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -66,7 +66,11 @@ class MultiModel: return self._models[self._current_model] async def llm_model_func( - self, prompt, system_prompt=None, history_messages=[], **kwargs + self, + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] = [], + **kwargs: Any, ) -> str: kwargs.pop("model", None) # stop from overwriting the custom model name kwargs.pop("keyword_extraction", None) diff --git a/lightrag/operate.py b/lightrag/operate.py index d6cc9f3c..f2bd6218 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1608,7 +1608,7 @@ async def kg_query_with_keywords( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -) -> str: +) -> str | AsyncIterator[str]: """ Refactored kg_query that does NOT extract keywords by itself. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. 
diff --git a/lightrag/utils.py b/lightrag/utils.py index c94e23cb..9b18d0c2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -9,7 +9,7 @@ import re from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List, Optional +from typing import Any, Callable, Union, List, Optional import xml.etree.ElementTree as ET import bs4 @@ -67,7 +67,7 @@ class EmbeddingFunc: @dataclass class ReasoningResponse: - reasoning_content: str + reasoning_content: str | None response_content: str tag: str @@ -109,7 +109,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]: raise e from None -def compute_args_hash(*args, cache_type: str = None) -> str: +def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: """Compute a hash for the given arguments. Args: *args: Arguments to hash @@ -220,11 +220,13 @@ def clean_str(input: Any) -> str: return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) -def is_float_regex(value): +def is_float_regex(value: str) -> bool: return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) -def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): +def truncate_list_by_token_size( + list_data: list[Any], key: Callable[[Any], str], max_token_size: int +) -> list[Any]: """Truncate a list of data by token size""" if max_token_size <= 0: return [] @@ -334,7 +336,7 @@ def xml_to_json(xml_file): return None -def process_combine_contexts(hl, ll): +def process_combine_contexts(hl: str, ll: str): header = None list_hl = csv_string_to_list(hl.strip()) list_ll = csv_string_to_list(ll.strip()) @@ -640,7 +642,9 @@ def exists_func(obj, func_name: str) -> bool: return False -def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str: +def get_conversation_turns( + conversation_history: list[dict[str, Any]], num_turns: int +) -> str: """ Process conversation history to get the specified number of complete turns. 
@@ -652,8 +656,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> Formatted string of the conversation history """ # Group messages into turns - turns = [] - messages = [] + turns: list[list[dict[str, Any]]] = [] + messages: list[dict[str, Any]] = [] # First, filter out keyword extraction messages for msg in conversation_history: @@ -687,7 +691,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> turns = turns[-num_turns:] # Format the turns into a string - formatted_turns = [] + formatted_turns: list[str] = [] for turn in turns: formatted_turns.extend( [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]