diff --git a/README.md b/README.md
index a1454792..a24c9b72 100644
--- a/README.md
+++ b/README.md
@@ -594,7 +594,7 @@ if __name__ == "__main__":
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
-| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
+| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index ff14787f..36576368 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -50,16 +50,17 @@ from .storage import (
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
+ # Get the caller's module and package
+ import inspect
+
+ caller_frame = inspect.currentframe().f_back
+ module = inspect.getmodule(caller_frame)
+ package = module.__package__ if module else None
+
def import_class(*args, **kwargs):
- import inspect
import importlib
- # Get the caller's module and package
- caller_frame = inspect.currentframe().f_back
- module = inspect.getmodule(caller_frame)
- package = module.__package__ if module else None
-
- # Import the module using importlib with package context
+ # Import the module using importlib
module = importlib.import_module(module_name, package=package)
# Get the class from the module and instantiate it
diff --git a/lightrag/llm.py b/lightrag/llm.py
index d725ea85..636f03cb 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -30,6 +30,7 @@ from .utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
+ logger,
)
import sys
@@ -63,12 +64,18 @@ async def openai_complete_if_cache(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
+ kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
+ # 添加日志输出
+ logger.debug("===== Query Input to LLM =====")
+ logger.debug(f"Query: {prompt}")
+ logger.debug(f"System prompt: {system_prompt}")
+ logger.debug("Full context:")
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 468f4b2f..8b8ad85b 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -260,6 +260,9 @@ async def extract_entities(
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
+ entity_types = global_config["addon_params"].get(
+ "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
+ )
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
examples = "\n".join(
@@ -272,7 +275,7 @@ async def extract_entities(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
- entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
+ entity_types=",".join(entity_types),
language=language,
)
# add example's format
@@ -283,7 +286,7 @@ async def extract_entities(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
- entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
+ entity_types=",".join(entity_types),
examples=examples,
language=language,
)
@@ -412,15 +415,17 @@ async def extract_entities(
):
all_relationships_data.append(await result)
- if not len(all_entities_data):
- logger.warning("Didn't extract any entities, maybe your LLM is not working")
- return None
- if not len(all_relationships_data):
+ if not len(all_entities_data) and not len(all_relationships_data):
logger.warning(
- "Didn't extract any relationships, maybe your LLM is not working"
+ "Didn't extract any entities and relationships, maybe your LLM is not working"
)
return None
+ if not len(all_entities_data):
+ logger.warning("Didn't extract any entities")
+ if not len(all_relationships_data):
+ logger.warning("Didn't extract any relationships")
+
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
@@ -630,6 +635,13 @@ async def _build_query_context(
text_chunks_db,
query_param,
)
+ if (
+ hl_entities_context == ""
+ and hl_relations_context == ""
+ and hl_text_units_context == ""
+ ):
+ logger.warn("No high level context found. Switching to local mode.")
+ query_param.mode = "local"
if query_param.mode == "hybrid":
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
@@ -865,7 +877,7 @@ async def _get_edge_data(
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
- return None
+ return "", "", ""
edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index b62f02b5..9d9e6034 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -8,7 +8,7 @@ PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
+PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
@@ -268,14 +268,19 @@ PROMPTS[
Question 1: {original_prompt}
Question 2: {cached_prompt}
-Please evaluate:
+Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
-
-Please provide a similarity score between 0 and 1, where:
-0: Completely unrelated or answer cannot be reused
+Similarity score criteria:
+0: Completely unrelated or answer cannot be reused, including but not limited to:
+ - The questions have different topics
+ - The locations mentioned in the questions are different
+ - The times mentioned in the questions are different
+ - The specific individuals mentioned in the questions are different
+ - The specific events mentioned in the questions are different
+ - The background information in the questions is different
+ - The key conditions in the questions are different
1: Identical and answer can be directly reused
0.5: Partially related and answer needs modification to be used
-
Return only a number between 0-1, without any additional content.
"""
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 007d6534..4c043893 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
embeddings = await f
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list)
- for i, d in enumerate(list_data):
- d["__vector__"] = embeddings[i]
- results = self._client.upsert(datas=list_data)
- return results
+ if len(embeddings) == len(list_data):
+ for i, d in enumerate(list_data):
+ d["__vector__"] = embeddings[i]
+ results = self._client.upsert(datas=list_data)
+ return results
+ else:
+ # sometimes the embedding is not returned correctly. just log it.
+ logger.error(
+ f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
+ )
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 0220af06..bdb47592 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -17,6 +17,17 @@ import tiktoken
from lightrag.prompt import PROMPTS
+
+class UnlimitedSemaphore:
+ """A context manager that allows unlimited access."""
+
+ async def __aenter__(self):
+ pass
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
+
+
ENCODER = None
logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
+ concurrent_limit: int = 16
+
+ def __post_init__(self):
+ if self.concurrent_limit != 0:
+ self._semaphore = asyncio.Semaphore(self.concurrent_limit)
+ else:
+ self._semaphore = UnlimitedSemaphore()
async def __call__(self, *args, **kwargs) -> np.ndarray:
- return await self.func(*args, **kwargs)
+ async with self._semaphore:
+ return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]: