Merge branch 'main' into graph-viewer-webui
README.md (20 lines changed)
@@ -3,7 +3,7 @@
 <table border="0" width="100%">
 <tr>
 <td width="100" align="center">
-<img src="https://github.com/user-attachments/assets/cb5b8fc1-0859-4f7c-8ec3-63c8ec7aa54b" width="80" height="80" alt="lightrag">
+<img src="https://i-blog.csdnimg.cn/direct/0d97ea81439442a19ac3972ad537a811.png" width="80" height="80" alt="lightrag">
 </td>
 <td>
 <div>
@@ -355,16 +355,26 @@ In order to run this experiment on low RAM GPU you should select small model and
 ```python
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
     response_type: str = "Multiple Paragraphs"
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
     top_k: int = 60
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    ...
 ```
 
 > The default value of top_k can be changed via the TOP_K environment variable.
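Note: for orientation, a minimal sketch of how the fields documented above might be combined in a query. The `rag` instance and the question text are assumptions for illustration, not part of this diff:

```python
from lightrag import QueryParam

# Sketch: tighten retrieval and request a bullet-point answer,
# or fetch only the retrieved context for inspection.
bullet_param = QueryParam(mode="hybrid", top_k=20, response_type="Bullet Points")
print(rag.query("Summarize the main characters.", param=bullet_param))  # rag assumed to exist

context_only = QueryParam(mode="local", only_need_context=True)
print(rag.query("Summarize the main characters.", param=context_only))
```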
@@ -24,6 +24,10 @@ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -36,10 +40,12 @@ async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        LLM_MODEL,
-        prompt,
+        model=LLM_MODEL,
+        prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
         **kwargs,
     )
 
@@ -49,8 +55,10 @@ async def llm_model_func(
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed(
-        texts,
+        texts=texts,
         model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
     )
 
 
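Note: the new BASE_URL and API_KEY settings let the demo target any OpenAI-compatible endpoint. A sketch of overriding them before the script reads them, with placeholder values that are not part of this commit (exporting the same variables in the shell works equally well):

```python
import os

# Placeholder endpoint and key for illustration only; the demo picks these
# up via os.environ.get(...) when it starts.
os.environ["BASE_URL"] = "http://localhost:8000/v1"  # hypothetical OpenAI-compatible server
os.environ["API_KEY"] = "sk-your-key"                # hypothetical key
os.environ["LLM_MODEL"] = "gpt-4o-mini"
```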
examples/lightrag_api_openai_compatible_demo_simplified.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
+print(f"WORKING_DIR: {WORKING_DIR}")
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
+print(f"BASE_URL: {BASE_URL}")
+API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
+print(f"API_KEY: {API_KEY}")
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+# LLM model function
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        model=LLM_MODEL,
+        prompt=prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        base_url=BASE_URL,
+        api_key=API_KEY,
+        **kwargs,
+    )
+
+
+# Embedding function
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await openai_embed(
+        texts=texts,
+        model=EMBEDDING_MODEL,
+        base_url=BASE_URL,
+        api_key=API_KEY,
+    )
+
+
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"{embedding_dim=}")
+    return embedding_dim
+
+
+# Initialize RAG instance
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=asyncio.run(get_embedding_dim()),
+        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+        func=embedding_func,
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
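Note: the same `rag` instance from this new demo could also return only the retrieved context rather than a generated answer. This sketch relies on the `QueryParam` fields documented elsewhere in this commit and is not part of the file above:

```python
# Sketch: inspect the retrieved context (knowledge-graph plus vector results in
# "mix" mode) without spending an LLM call on answer generation.
context = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="mix", only_need_context=True),
)
print(context)
```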
@@ -1,7 +1,7 @@
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
 
 WORKING_DIR = "./dickens"
 
@@ -10,6 +10,7 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
+    embedding_func=openai_embed,
     llm_model_func=gpt_4o_mini_complete,
     # llm_model_func=gpt_4o_complete
 )
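Note: a short usage sketch for the updated demo above (not part of the commit). Passing `embedding_func=openai_embed` explicitly appears to be needed now that `LightRAG` no longer ships a default embedding function (see the core change further down); the book path below is assumed, as in the other demos:

```python
# Sketch: index a document and run a hybrid query with the instance above.
with open("./book.txt", "r", encoding="utf-8") as f:  # path assumed for illustration
    rag.insert(f.read())

print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```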
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.1.5"
+__version__ = "1.1.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -27,30 +27,54 @@ T = TypeVar("T")
 
 @dataclass
 class QueryParam:
+    """Configuration parameters for query execution in LightRAG."""
+
     mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
+    """Specifies the retrieval mode:
+    - "local": Focuses on context-dependent information.
+    - "global": Utilizes global knowledge.
+    - "hybrid": Combines local and global retrieval methods.
+    - "naive": Performs a basic search without advanced techniques.
+    - "mix": Integrates knowledge graph and vector retrieval.
+    """
+
     only_need_context: bool = False
+    """If True, only returns the retrieved context without generating a response."""
+
     only_need_prompt: bool = False
+    """If True, only returns the generated prompt without producing a response."""
+
     response_type: str = "Multiple Paragraphs"
+    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
+
     stream: bool = False
-    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    """If True, enables streaming output for real-time responses."""
+
     top_k: int = int(os.getenv("TOP_K", "60"))
-    # Number of document chunks to retrieve.
-    # top_n: int = 10
-    # Number of tokens for the original chunks.
+    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
+
     max_token_for_text_unit: int = 4000
-    # Number of tokens for the relationship descriptions
+    """Maximum number of tokens allowed for each retrieved text chunk."""
+
     max_token_for_global_context: int = 4000
-    # Number of tokens for the entity descriptions
+    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
+
     max_token_for_local_context: int = 4000
+    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+
     hl_keywords: list[str] = field(default_factory=list)
+    """List of high-level keywords to prioritize in retrieval."""
+
     ll_keywords: list[str] = field(default_factory=list)
-    # Conversation history support
-    conversation_history: list[dict[str, str]] = field(
-        default_factory=list
-    )  # Format: [{"role": "user/assistant", "content": "message"}]
-    history_turns: int = (
-        3  # Number of complete conversation turns (user-assistant pairs) to consider
-    )
+    """List of low-level keywords to refine retrieval focus."""
+
+    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    """Stores past conversation history to maintain context.
+    Format: [{"role": "user/assistant", "content": "message"}].
+    """
+
+    history_turns: int = 3
+    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
 
 
 @dataclass
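Note: a hedged sketch of how the newly documented conversation fields might be populated. The question strings are illustrative and the `rag` instance is assumed:

```python
# Sketch: carry prior turns into a query so the answer stays in context.
param = QueryParam(
    mode="hybrid",
    conversation_history=[
        {"role": "user", "content": "Who is the main character?"},
        {"role": "assistant", "content": "The story follows Ebenezer Scrooge."},
    ],
    history_turns=3,  # number of user-assistant pairs taken into account
)
answer = rag.query("How does he change by the end?", param=param)  # rag assumed to exist
```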
@@ -109,6 +109,22 @@ class JsonDocStatusStorage(DocStatusStorage):
             if v["status"] == DocStatus.PENDING
         }
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSED
+        }
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PROCESSING
+        }
+
     async def index_done_callback(self):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)
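Note: a sketch of how the new status queries might be used alongside the existing ones; the `doc_status` storage instance and the surrounding coroutine are assumptions:

```python
async def report_status(doc_status) -> None:
    # Sketch: count documents in each tracked state using the methods above.
    processed = await doc_status.get_processed_docs()
    processing = await doc_status.get_processing_docs()
    print(f"processed={len(processed)}, still processing={len(processing)}")
```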
@@ -109,38 +109,65 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
 @dataclass
 class LightRAG:
+    """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
+
     working_dir: str = field(
         default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
     )
-    # Default not to use embedding cache
-    embedding_cache_config: dict = field(
+    """Directory where cache and temporary files are stored."""
+
+    embedding_cache_config: dict[str, Any] = field(
         default_factory=lambda: {
             "enabled": False,
             "similarity_threshold": 0.95,
             "use_llm_check": False,
         }
     )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """
+
     kv_storage: str = field(default="JsonKVStorage")
+    """Storage backend for key-value data."""
+
     vector_storage: str = field(default="NanoVectorDBStorage")
+    """Storage backend for vector embeddings."""
+
     graph_storage: str = field(default="NetworkXStorage")
+    """Storage backend for knowledge graphs."""
 
-    # logging
+    # Logging
     current_log_level = logger.level
-    log_level: str = field(default=current_log_level)
+    log_level: int = field(default=current_log_level)
+    """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
+
     log_dir: str = field(default=os.getcwd())
+    """Directory where logs are stored. Defaults to the current working directory."""
 
-    # text chunking
+    # Text chunking
     chunk_token_size: int = 1200
+    """Maximum number of tokens per text chunk when splitting documents."""
+
     chunk_overlap_token_size: int = 100
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
     tiktoken_model_name: str = "gpt-4o-mini"
+    """Model name used for tokenization when chunking text."""
 
-    # entity extraction
+    # Entity extraction
     entity_extract_max_gleaning: int = 1
-    entity_summary_to_max_tokens: int = 500
-
-    # node embedding
+    """Maximum number of entity extraction attempts for ambiguous content."""
+
+    entity_summary_to_max_tokens: int = 500
+    """Maximum number of tokens used for summarizing extracted entities."""
+
+    # Node embedding
     node_embedding_algorithm: str = "node2vec"
-    node2vec_params: dict = field(
+    """Algorithm used for node embedding in knowledge graphs."""
+
+    node2vec_params: dict[str, int] = field(
         default_factory=lambda: {
             "dimensions": 1536,
             "num_walks": 10,
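Note: a minimal construction sketch (illustrative values only) showing how the embedding cache documented above could be switched on; `llm_model_func` and `embedding_func` are assumed to be defined as in the example scripts, and the working directory is a placeholder:

```python
# Sketch: enable embedding caching so near-duplicate inputs reuse cached vectors.
rag = LightRAG(
    working_dir="./lightrag_cache_demo",  # placeholder path
    llm_model_func=llm_model_func,        # assumed, as in the demos
    embedding_func=embedding_func,        # assumed EmbeddingFunc wrapper, as in the demos
    embedding_cache_config={
        "enabled": True,
        "similarity_threshold": 0.95,
        "use_llm_check": False,
    },
)
```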
@@ -150,26 +177,56 @@ class LightRAG:
             "random_seed": 3,
         }
     )
+    """Configuration for the node2vec embedding algorithm:
+    - dimensions: Number of dimensions for embeddings.
+    - num_walks: Number of random walks per node.
+    - walk_length: Number of steps per random walk.
+    - window_size: Context window size for training.
+    - iterations: Number of iterations for training.
+    - random_seed: Seed value for reproducibility.
+    """
 
-    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = None  # This must be set (we do want to separate llm from the corte, so no more default initialization)
+    embedding_func: EmbeddingFunc = None
+    """Function for computing text embeddings. Must be set before use."""
+
     embedding_batch_num: int = 32
+    """Batch size for embedding computations."""
+
     embedding_func_max_async: int = 16
+    """Maximum number of concurrent embedding function calls."""
 
-    # LLM
-    llm_model_func: callable = None  # This must be set (we do want to separate llm from the corte, so no more default initialization)
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    # LLM Configuration
+    llm_model_func: callable = None
+    """Function for interacting with the large language model (LLM). Must be set before use."""
+
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    """Name of the LLM model used for generating responses."""
+
     llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
-    llm_model_kwargs: dict = field(default_factory=dict)
+    """Maximum number of tokens allowed per LLM response."""
+
+    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    """Maximum number of concurrent LLM calls."""
+
+    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
+    """Additional keyword arguments passed to the LLM model function."""
 
-    # storage
-    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
+    # Storage
+    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
+    """Additional parameters for vector database storage."""
+
     namespace_prefix: str = field(default="")
+    """Prefix for namespacing stored data across different environments."""
+
     enable_llm_cache: bool = True
-    # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
+    """Enables caching for LLM responses to avoid redundant computations."""
+
     enable_llm_cache_for_entity_extract: bool = True
+    """If True, enables caching for entity extraction steps to reduce LLM costs."""
+
+    # Extensions
+    addon_params: dict[str, Any] = field(default_factory=dict)
+    """Dictionary for additional parameters and extensions."""
 
     # extension
     addon_params: dict[str, Any] = field(default_factory=dict)
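Note: several of the limits documented above are read from environment variables when the dataclass defaults are evaluated, so they can be tuned per deployment. A sketch with illustrative values, assuming the variables are set before `lightrag` is imported:

```python
import os

# Illustrative overrides; any unset variable falls back to the defaults shown in the diff.
os.environ["MAX_TOKENS"] = "16384"  # llm_model_max_token_size
os.environ["MAX_ASYNC"] = "8"       # llm_model_max_async
os.environ["TOP_K"] = "40"          # default QueryParam.top_k

from lightrag import LightRAG, QueryParam  # import only after setting the variables
```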
@@ -177,8 +234,8 @@ class LightRAG:
         convert_response_to_json
     )
 
-    # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")
+    """Storage type for tracking document processing statuses."""
 
     # Custom Chunking Function
     chunking_func: Callable[
@@ -486,7 +543,7 @@ class LightRAG:
         new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
 
         if not new_docs:
-            logger.info("All documents have been processed or are duplicates")
+            logger.info("No new unique documents were found.")
             return
 
         # 4. Store status document
@@ -503,15 +560,16 @@ class LightRAG:
         each chunk for entity and relation extraction, and updating the
         document status.
 
-        1. Get all pending and failed documents
+        1. Get all pending, failed, and abnormally terminated processing documents.
         2. Split document content into chunks
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
-        # 1. get all pending and failed documents
+        # 1. Get all pending, failed, and abnormally terminated processing documents.
         to_process_docs: dict[str, DocProcessingStatus] = {}
 
-        # Fetch failed documents
+        processing_docs = await self.doc_status.get_processing_docs()
+        to_process_docs.update(processing_docs)
         failed_docs = await self.doc_status.get_failed_docs()
         to_process_docs.update(failed_docs)
         pendings_docs = await self.doc_status.get_pending_docs()
@@ -542,6 +600,7 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.PROCESSING,
                     "updated_at": datetime.now().isoformat(),
+                    "content": status_doc.content,
                     "content_summary": status_doc.content_summary,
                     "content_length": status_doc.content_length,
                     "created_at": status_doc.created_at,
@@ -578,6 +637,10 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.PROCESSED,
                     "chunks_count": len(chunks),
+                    "content": status_doc.content,
+                    "content_summary": status_doc.content_summary,
+                    "content_length": status_doc.content_length,
+                    "created_at": status_doc.created_at,
                     "updated_at": datetime.now().isoformat(),
                 }
             }
@@ -591,6 +654,10 @@ class LightRAG:
                 doc_status_id: {
                     "status": DocStatus.FAILED,
                     "error": str(e),
+                    "content": status_doc.content,
+                    "content_summary": status_doc.content_summary,
+                    "content_length": status_doc.content_length,
+                    "created_at": status_doc.created_at,
                     "updated_at": datetime.now().isoformat(),
                 }
             }
@@ -103,17 +103,19 @@ async def openai_complete_if_cache(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
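Note: the change above reverses the previous behaviour. Instead of writing an explicit key into the process environment, the OpenAI client now receives the key directly and falls back to `OPENAI_API_KEY` only when none is passed. A distilled sketch of that resolution order (the helper name is hypothetical):

```python
import os

def resolve_api_key(api_key: str | None = None) -> str:
    # Explicit argument wins; otherwise fall back to the environment.
    # A KeyError is raised if neither source provides a key.
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]
    return api_key
```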
@@ -294,17 +296,19 @@ async def openai_embed(
     base_url: str = None,
     api_key: str = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["OPENAI_API_KEY"] = api_key
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
 
     default_headers = {
         "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
         "Content-Type": "application/json",
     }
     openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers)
+        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
         if base_url is None
-        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
+        else AsyncOpenAI(
+            base_url=base_url, default_headers=default_headers, api_key=api_key
+        )
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
@@ -1504,7 +1504,7 @@ async def naive_query(
     use_model_func = global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, query, "default", cache_type="query"
+        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
     if cached_response is not None:
         return cached_response