cleaned code
@@ -107,9 +107,11 @@ class BaseVectorStorage(StorageNameSpace):
         raise NotImplementedError
 
     async def delete_entity(self, entity_name: str) -> None:
+        """Delete a single entity by its name"""
         raise NotImplementedError
 
     async def delete_entity_relation(self, entity_name: str) -> None:
+        """Delete relations for a given entity by scanning metadata"""
         raise NotImplementedError
 
 
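For context, a minimal self-contained sketch of a backend satisfying the delete contract documented above; the in-memory structures and the src_id/tgt_id metadata keys are assumptions for illustration, not part of this commit:

    class InMemoryVectorIndex:
        """Sketch only: mirrors the BaseVectorStorage delete contract."""

        def __init__(self) -> None:
            self._entities: dict[str, list[float]] = {}
            # relation id -> metadata; src_id/tgt_id are assumed metadata keys
            self._relations: dict[str, dict] = {}

        async def delete_entity(self, entity_name: str) -> None:
            """Delete a single entity by its name."""
            self._entities.pop(entity_name, None)

        async def delete_entity_relation(self, entity_name: str) -> None:
            """Delete relations for a given entity by scanning metadata."""
            self._relations = {
                rid: meta
                for rid, meta in self._relations.items()
                if entity_name not in (meta.get("src_id"), meta.get("tgt_id"))
            }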
@@ -524,7 +524,6 @@ class LightRAG:
             embedding_func=None,
         )
 
-        # What's for, Is this nessisary ?
         if self.llm_response_cache and hasattr(
             self.llm_response_cache, "global_config"
         ):
 
@@ -1252,7 +1251,7 @@ class LightRAG:
         """
         return await self.doc_status.get_status_counts()
 
-    async def adelete_by_doc_id(self, doc_id: str):
+    async def adelete_by_doc_id(self, doc_id: str) -> None:
         """Delete a document and all its related data
 
         Args:
 
@@ -1269,6 +1268,9 @@ class LightRAG:
 
         # 2. Get all related chunks
         chunks = await self.text_chunks.get_by_id(doc_id)
+        if not chunks:
+            return
+
         chunk_ids = list(chunks.keys())
         logger.debug(f"Found {len(chunk_ids)} chunks to delete")
 
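The new guard matters because a lookup for an unknown doc id comes back empty, and the old code went straight to chunks.keys(). A runnable sketch of the pattern, with the storage layer stubbed by a plain dict:

    import asyncio

    # Stand-in for self.text_chunks; maps doc_id -> {chunk_id: chunk}.
    _FAKE_STORE: dict[str, dict[str, dict]] = {}

    async def delete_doc(doc_id: str) -> None:
        chunks = _FAKE_STORE.get(doc_id)  # plays the role of await get_by_id(doc_id)
        if not chunks:
            return  # unknown document: nothing to delete, and no crash on .keys()
        chunk_ids = list(chunks.keys())
        print(f"Found {len(chunk_ids)} chunks to delete")

    asyncio.run(delete_doc("missing-doc"))  # returns quietly instead of raising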
@@ -66,7 +66,11 @@ class MultiModel:
         return self._models[self._current_model]
 
     async def llm_model_func(
-        self, prompt, system_prompt=None, history_messages=[], **kwargs
+        self,
+        prompt: str,
+        system_prompt: str | None = None,
+        history_messages: list[dict[str, Any]] = [],
+        **kwargs: Any,
     ) -> str:
         kwargs.pop("model", None)  # stop from overwriting the custom model name
         kwargs.pop("keyword_extraction", None)
 
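With the expanded signature, call sites become self-documenting. A hedged usage sketch (multi_model and the extra kwargs are placeholders, not part of this commit):

    import asyncio

    async def main(multi_model) -> None:
        # "model" and "keyword_extraction" would be stripped by the pops above.
        answer = await multi_model.llm_model_func(
            "Summarize the document.",
            system_prompt="You are a concise assistant.",
            history_messages=[{"role": "user", "content": "hi"}],
            temperature=0.2,  # any extra kwarg is forwarded to the underlying model
        )
        print(answer)

One design note: history_messages keeps a mutable [] default, which Python shares across calls; the usual idiom is a None default replaced with a fresh list inside the function.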
@@ -1608,7 +1608,7 @@ async def kg_query_with_keywords(
     query_param: QueryParam,
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
-) -> str:
+) -> str | AsyncIterator[str]:
     """
     Refactored kg_query that does NOT extract keywords by itself.
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
 
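Since the annotated return is now str | AsyncIterator[str], callers that enable streaming have to branch on the result shape. A minimal consumer handling either case:

    from typing import AsyncIterator

    async def consume(result: str | AsyncIterator[str]) -> str:
        # Non-streaming path: the function returned the full answer.
        if isinstance(result, str):
            return result
        # Streaming path: drain the async iterator chunk by chunk.
        parts: list[str] = []
        async for chunk in result:
            parts.append(chunk)
        return "".join(parts)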
@@ -9,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union, List, Optional
+from typing import Any, Callable, Union, List, Optional
 import xml.etree.ElementTree as ET
 import bs4
 
@@ -67,7 +67,7 @@ class EmbeddingFunc:
 
 @dataclass
 class ReasoningResponse:
-    reasoning_content: str
+    reasoning_content: str | None
     response_content: str
     tag: str
 
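With reasoning_content now optional, consumers should treat a missing reasoning trace as the normal case. A small sketch (the dataclass mirrors the fields above; the tag markup is illustrative only, not the library's output format):

    from dataclasses import dataclass

    @dataclass
    class ReasoningResponse:  # mirrors the fields shown in the hunk above
        reasoning_content: str | None
        response_content: str
        tag: str

    def render(r: ReasoningResponse) -> str:
        # A model that emits no reasoning trace leaves reasoning_content as None.
        if r.reasoning_content is None:
            return r.response_content
        return f"<{r.tag}>{r.reasoning_content}</{r.tag}>\n{r.response_content}"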
@@ -109,7 +109,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
         raise e from None
 
 
-def compute_args_hash(*args, cache_type: str = None) -> str:
+def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
 
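A minimal body consistent with the new signature, using the md5 import visible earlier in this file; the serialization scheme here is an assumption, and the repo's actual implementation may differ:

    from hashlib import md5
    from typing import Any

    def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
        """Compute a hash for the given arguments."""
        # Prefix with the cache type so identical args in different caches differ.
        payload = (f"{cache_type}:" if cache_type else "") + str(args)
        return md5(payload.encode("utf-8")).hexdigest()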
@@ -220,11 +220,13 @@ def clean_str(input: Any) -> str:
     return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
 
 
-def is_float_regex(value):
+def is_float_regex(value: str) -> bool:
     return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
 
 
-def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
+def truncate_list_by_token_size(
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
         return []
 
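The Callable[[Any], str] annotation spells out that key maps each item to the text being budgeted. A sketch of the technique under a placeholder tokenizer; note the hunk annotates the return as list[int], though the docstring and usage suggest a list of the original items:

    from typing import Any, Callable

    def count_tokens(text: str) -> int:
        # Placeholder tokenizer: whitespace split; real code would use a BPE encoder.
        return len(text.split())

    def truncate_by_token_size(
        list_data: list[Any], key: Callable[[Any], str], max_token_size: int
    ) -> list[Any]:
        """Keep the longest prefix of list_data whose token total fits the budget."""
        if max_token_size <= 0:
            return []
        total = 0
        for i, item in enumerate(list_data):
            total += count_tokens(key(item))
            if total > max_token_size:
                return list_data[:i]
        return list_data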
@@ -334,7 +336,7 @@ def xml_to_json(xml_file):
         return None
 
 
-def process_combine_contexts(hl, ll):
+def process_combine_contexts(hl: str, ll: str):
     header = None
     list_hl = csv_string_to_list(hl.strip())
     list_ll = csv_string_to_list(ll.strip())
 
@@ -640,7 +642,9 @@ def exists_func(obj, func_name: str) -> bool:
     return False
 
 
-def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
+def get_conversation_turns(
+    conversation_history: list[dict[str, Any]], num_turns: int
+) -> str:
     """
     Process conversation history to get the specified number of complete turns.
 
@@ -652,8 +656,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
         Formatted string of the conversation history
     """
     # Group messages into turns
-    turns = []
-    messages = []
+    turns: list[list[dict[str, Any]]] = []
+    messages: list[dict[str, Any]] = []
 
     # First, filter out keyword extraction messages
     for msg in conversation_history:
 
@@ -687,7 +691,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
         turns = turns[-num_turns:]
 
     # Format the turns into a string
-    formatted_turns = []
+    formatted_turns: list[str] = []
     for turn in turns:
         formatted_turns.extend(
             [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
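Taken together, the function pairs user/assistant messages into turns, keeps the last num_turns, and flattens them into a transcript. A usage sketch based only on the behavior visible in these hunks (the import path is assumed):

    from lightrag.utils import get_conversation_turns  # assumed module path

    history = [
        {"role": "user", "content": "What is LightRAG?"},
        {"role": "assistant", "content": "A retrieval-augmented generation library."},
        {"role": "user", "content": "Does it support graph storage?"},
        {"role": "assistant", "content": "Yes, via pluggable backends."},
    ]

    # Expected shape of the result for num_turns=1:
    #   user: Does it support graph storage?
    #   assistant: Yes, via pluggable backends.
    print(get_conversation_turns(history, num_turns=1))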