Merge pull request #1051 from HKUDS/dev
Refactor LightRAG for better code organization
@@ -30,11 +30,10 @@ from .namespace import NameSpace, make_namespace
 from .operate import (
     chunking_by_token_size,
     extract_entities,
-    extract_keywords_only,
     kg_query,
-    kg_query_with_keywords,
     mix_kg_vector_query,
     naive_query,
+    query_with_keywords,
 )
 from .prompt import GRAPH_FIELD_SEP, PROMPTS
 from .utils import (
@@ -45,6 +44,9 @@ from .utils import (
     encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
+    get_content_summary,
+    clean_text,
+    check_storage_env_vars,
     logger,
 )
 from .types import KnowledgeGraph
@@ -309,7 +311,7 @@ class LightRAG:
         # Verify storage implementation compatibility
         verify_storage_implementation(storage_type, storage_name)
         # Check environment variables
-        # self.check_storage_env_vars(storage_name)
+        check_storage_env_vars(storage_name)

         # Ensure vector_db_storage_cls_kwargs has required fields
         self.vector_db_storage_cls_kwargs = {
@@ -536,11 +538,6 @@ class LightRAG:
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class

-    @staticmethod
-    def clean_text(text: str) -> str:
-        """Clean text by removing null bytes (0x00) and whitespace"""
-        return text.strip().replace("\x00", "")
-
     def insert(
         self,
         input: str | list[str],
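Since clean_text is no longer a static method, any external caller has to switch from the class attribute to the lightrag.utils import. A before/after sketch (the caller code is hypothetical, not part of the commit):

    # Before: invoked via the class
    # cleaned = LightRAG.clean_text("  some text\x00  ")

    # After: imported as a module-level helper
    from lightrag.utils import clean_text

    cleaned = clean_text("  some text\x00  ")
    print(cleaned)  # -> "some text" (whitespace stripped, null bytes removed)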
@@ -602,8 +599,8 @@ class LightRAG:
         update_storage = False
         try:
             # Clean input texts
-            full_text = self.clean_text(full_text)
-            text_chunks = [self.clean_text(chunk) for chunk in text_chunks]
+            full_text = clean_text(full_text)
+            text_chunks = [clean_text(chunk) for chunk in text_chunks]

             # Process cleaned texts
             if doc_id is None:
@@ -682,7 +679,7 @@ class LightRAG:
             contents = {id_: doc for id_, doc in zip(ids, input)}
         else:
             # Clean input text and remove duplicates
-            input = list(set(self.clean_text(doc) for doc in input))
+            input = list(set(clean_text(doc) for doc in input))
             # Generate contents dict of MD5 hash IDs and documents
             contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}

@@ -698,7 +695,7 @@ class LightRAG:
         new_docs: dict[str, Any] = {
             id_: {
                 "content": content,
-                "content_summary": self._get_content_summary(content),
+                "content_summary": get_content_summary(content),
                 "content_length": len(content),
                 "status": DocStatus.PENDING,
                 "created_at": datetime.now().isoformat(),
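The per-document summary is now produced by the module-level get_content_summary helper (defined in full in the utils.py hunk near the end of this diff). A small sketch of its truncation behavior:

    from lightrag.utils import get_content_summary

    print(get_content_summary("short document"))  # under 100 chars: returned as-is
    print(get_content_summary("y" * 150))         # truncated to 100 chars plus "..."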
@@ -1063,7 +1060,7 @@ class LightRAG:
         all_chunks_data: dict[str, dict[str, str]] = {}
         chunk_to_source_map: dict[str, str] = {}
         for chunk_data in custom_kg.get("chunks", []):
-            chunk_content = self.clean_text(chunk_data["content"])
+            chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
             tokens = len(
                 encode_string_by_tiktoken(
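encode_string_by_tiktoken is used here to count tokens per chunk. A rough standalone equivalent (assuming the helper wraps tiktoken; the cl100k_base encoding is an assumption, not taken from this diff):

    import tiktoken

    # Count tokens the way the chunk-ingestion path does: encode, then len()
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = len(encoding.encode("chunk content goes here"))
    print(tokens)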
@@ -1296,8 +1293,17 @@ class LightRAG:
         self, query: str, prompt: str, param: QueryParam = QueryParam()
     ):
         """
-        1. Extract keywords from the 'query' using new function in operate.py.
-        2. Then run the standard aquery() flow with the final prompt (formatted_question).
+        Query with separate keyword extraction step.
+
+        This method extracts keywords from the query first, then uses them for the query.
+
+        Args:
+            query: User query
+            prompt: Additional prompt for the query
+            param: Query parameters
+
+        Returns:
+            Query response
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
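The synchronous wrapper is unchanged in shape: it obtains an event loop and drives the async variant to completion. A minimal sketch of that pattern in isolation (this always_get_an_event_loop is a simplified stand-in for the library helper of the same name):

    import asyncio

    def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
        # Reuse the current loop if possible, otherwise create and register one
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    async def do_query() -> str:
        return "response"

    loop = always_get_an_event_loop()
    print(loop.run_until_complete(do_query()))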
@@ -1308,66 +1314,29 @@ class LightRAG:
         self, query: str, prompt: str, param: QueryParam = QueryParam()
     ) -> str | AsyncIterator[str]:
         """
-        1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
-        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
+        Async version of query_with_separate_keyword_extraction.
+
+        Args:
+            query: User query
+            prompt: Additional prompt for the query
+            param: Query parameters
+
+        Returns:
+            Query response or async iterator
         """
-        # ---------------------
-        # STEP 1: Keyword Extraction
-        # ---------------------
-        hl_keywords, ll_keywords = await extract_keywords_only(
-            text=query,
+        response = await query_with_keywords(
+            query=query,
+            prompt=prompt,
             param=param,
+            knowledge_graph_inst=self.chunk_entity_relation_graph,
+            entities_vdb=self.entities_vdb,
+            relationships_vdb=self.relationships_vdb,
+            chunks_vdb=self.chunks_vdb,
+            text_chunks_db=self.text_chunks,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+            hashing_kv=self.llm_response_cache,
         )

-        param.hl_keywords = hl_keywords
-        param.ll_keywords = ll_keywords
-
-        # ---------------------
-        # STEP 2: Final Query Logic
-        # ---------------------
-
-        # Create a new string with the prompt and the keywords
-        ll_keywords_str = ", ".join(ll_keywords)
-        hl_keywords_str = ", ".join(hl_keywords)
-        formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
-
-        if param.mode in ["local", "global", "hybrid"]:
-            response = await kg_query_with_keywords(
-                formatted_question,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        elif param.mode == "naive":
-            response = await naive_query(
-                formatted_question,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        elif param.mode == "mix":
-            response = await mix_kg_vector_query(
-                formatted_question,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        else:
-            raise ValueError(f"Unknown mode {param.mode}")
-
         await self._query_done()
         return response
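From a caller's perspective nothing changes: the public signature stays the same and mode dispatch now lives in operate.query_with_keywords. A hedged usage sketch (constructor arguments such as the LLM and embedding functions are omitted; "hybrid" is just one of the modes handled above):

    from lightrag import LightRAG, QueryParam

    rag = LightRAG(working_dir="./rag_storage")  # other required settings omitted

    answer = rag.query_with_separate_keyword_extraction(
        query="What storage backends does LightRAG support?",
        prompt="Answer concisely for a technical audience.",
        param=QueryParam(mode="hybrid"),
    )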
@@ -1465,21 +1434,6 @@ class LightRAG:
             ]
         )

-    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
-        """Get summary of document content
-
-        Args:
-            content: Original document content
-            max_length: Maximum length of summary
-
-        Returns:
-            Truncated content with ellipsis if needed
-        """
-        content = content.strip()
-        if len(content) <= max_length:
-            return content
-        return content[:max_length] + "..."
-
     async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts

@@ -1916,3 +1916,90 @@ async def kg_query_with_keywords(
     )

     return response
+
+
+async def query_with_keywords(
+    query: str,
+    prompt: str,
+    param: QueryParam,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    chunks_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
+    """
+    Extract keywords from the query and then use them for retrieving information.
+
+    1. Extracts high-level and low-level keywords from the query
+    2. Formats the query with the extracted keywords and prompt
+    3. Uses the appropriate query method based on param.mode
+
+    Args:
+        query: The user's query
+        prompt: Additional prompt to prepend to the query
+        param: Query parameters
+        knowledge_graph_inst: Knowledge graph storage
+        entities_vdb: Entities vector database
+        relationships_vdb: Relationships vector database
+        chunks_vdb: Document chunks vector database
+        text_chunks_db: Text chunks storage
+        global_config: Global configuration
+        hashing_kv: Cache storage
+
+    Returns:
+        Query response or async iterator
+    """
+    # Extract keywords
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        text=query,
+        param=param,
+        global_config=global_config,
+        hashing_kv=hashing_kv,
+    )
+
+    param.hl_keywords = hl_keywords
+    param.ll_keywords = ll_keywords
+
+    # Create a new string with the prompt and the keywords
+    ll_keywords_str = ", ".join(ll_keywords)
+    hl_keywords_str = ", ".join(hl_keywords)
+    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+    # Use appropriate query method based on mode
+    if param.mode in ["local", "global", "hybrid"]:
+        return await kg_query_with_keywords(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "naive":
+        return await naive_query(
+            formatted_question,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "mix":
+        return await mix_kg_vector_query(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    else:
+        raise ValueError(f"Unknown mode {param.mode}")
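To make the prompt assembly concrete, this is what formatted_question evaluates to for sample inputs (keyword values invented for illustration):

    prompt = "Answer using the knowledge graph."
    query = "What storage backends does LightRAG support?"
    hl_keywords_str = "storage architecture"
    ll_keywords_str = "backends, vector database"

    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"

    # formatted_question now reads:
    # Answer using the knowledge graph.
    #
    # ### Keywords:
    # High-level: storage architecture
    # Low-level: backends, vector database
    #
    # ### Query:
    # What storage backends does LightRAG support?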
@@ -890,3 +890,52 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
         return cls(*args, **kwargs)

     return import_class
+
+
+def get_content_summary(content: str, max_length: int = 100) -> str:
+    """Get summary of document content
+
+    Args:
+        content: Original document content
+        max_length: Maximum length of summary
+
+    Returns:
+        Truncated content with ellipsis if needed
+    """
+    content = content.strip()
+    if len(content) <= max_length:
+        return content
+    return content[:max_length] + "..."
+
+
+def clean_text(text: str) -> str:
+    """Clean text by removing null bytes (0x00) and whitespace
+
+    Args:
+        text: Input text to clean
+
+    Returns:
+        Cleaned text
+    """
+    return text.strip().replace("\x00", "")
+
+
+def check_storage_env_vars(storage_name: str) -> None:
+    """Check if all required environment variables for storage implementation exist
+
+    Args:
+        storage_name: Storage implementation name
+
+    Raises:
+        ValueError: If required environment variables are missing
+    """
+    from lightrag.kg import STORAGE_ENV_REQUIREMENTS
+
+    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+    missing_vars = [var for var in required_vars if var not in os.environ]
+
+    if missing_vars:
+        raise ValueError(
+            f"Storage implementation '{storage_name}' requires the following "
+            f"environment variables: {', '.join(missing_vars)}"
+        )
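A quick sketch exercising the other two relocated helpers (MongoKVStorage is used only as an illustrative storage name; the authoritative requirements live in lightrag.kg.STORAGE_ENV_REQUIREMENTS):

    from lightrag.utils import check_storage_env_vars, clean_text

    # strip() runs first, then null bytes are removed
    print(clean_text("  hello\x00world  "))  # -> "helloworld"

    # Raises ValueError naming any missing environment variables
    try:
        check_storage_env_vars("MongoKVStorage")
    except ValueError as e:
        print(e)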