Merge branch 'main' into pkaushal/vectordb-chroma

@@ -50,16 +50,17 @@ from .storage import (
 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
 
+    # Get the caller's module and package
+    import inspect
+
+    caller_frame = inspect.currentframe().f_back
+    module = inspect.getmodule(caller_frame)
+    package = module.__package__ if module else None
+
     def import_class(*args, **kwargs):
-        import inspect
         import importlib
 
-        # Get the caller's module and package
-        caller_frame = inspect.currentframe().f_back
-        module = inspect.getmodule(caller_frame)
-        package = module.__package__ if module else None
-
-        # Import the module using importlib with package context
+        # Import the module using importlib
         module = importlib.import_module(module_name, package=package)
 
         # Get the class from the module and instantiate it
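
Note (not part of the diff): this hunk moves the caller-frame inspection out of import_class, so the caller's package is captured once when the wrapper is created rather than on every deferred call. A minimal self-contained sketch of the resulting behavior; the getattr/instantiate tail is assumed from the truncated hunk:

    import importlib
    import inspect

    def lazy_external_import(module_name: str, class_name: str):
        """Lazily import a class from an external module based on the package of the caller."""
        # Capture the defining caller's package once, at wrapper-creation time.
        caller_frame = inspect.currentframe().f_back
        module = inspect.getmodule(caller_frame)
        package = module.__package__ if module else None

        def import_class(*args, **kwargs):
            # Deferred import; a relative module_name resolves against `package`.
            module = importlib.import_module(module_name, package=package)
            cls = getattr(module, class_name)  # assumed tail of the truncated hunk
            return cls(*args, **kwargs)

        return import_class

    # Hypothetical usage with a stdlib class:
    JSONDecoder = lazy_external_import("json", "JSONDecoder")
    decoder = JSONDecoder()
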
@@ -30,6 +30,7 @@ from .utils import (
     wrap_embedding_func_with_attrs,
     locate_json_string_body_from_string,
     safe_unicode_decode,
+    logger,
 )
 
 import sys
@@ -63,12 +64,18 @@ async def openai_complete_if_cache(
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
+
+    # Log the query input
+    logger.debug("===== Query Input to LLM =====")
+    logger.debug(f"Query: {prompt}")
+    logger.debug(f"System prompt: {system_prompt}")
+    logger.debug("Full context:")
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
             model=model, messages=messages, **kwargs
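
Note (not part of the diff): the new trace lines go through logger.debug, so they are silent at the default level. A minimal sketch of how a caller might surface them, using only the standard logging API:

    import logging

    logging.basicConfig(format="%(levelname)s: %(message)s")
    logging.getLogger("lightrag").setLevel(logging.DEBUG)
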
@@ -260,6 +260,9 @@ async def extract_entities(
     language = global_config["addon_params"].get(
         "language", PROMPTS["DEFAULT_LANGUAGE"]
     )
+    entity_types = global_config["addon_params"].get(
+        "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
+    )
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
         examples = "\n".join(
@@ -272,7 +275,7 @@ async def extract_entities(
         tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
         record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
         completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
-        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
+        entity_types=",".join(entity_types),
         language=language,
     )
     # add example's format
@@ -283,7 +286,7 @@ async def extract_entities(
         tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
         record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
         completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
-        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
+        entity_types=",".join(entity_types),
         examples=examples,
         language=language,
     )
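
Note (not part of the diff): the three hunks above make the entity types configurable: addon_params overrides PROMPTS["DEFAULT_ENTITY_TYPES"], and both prompt templates now interpolate the override. An illustrative, self-contained reduction of the lookup and join:

    PROMPTS = {"DEFAULT_ENTITY_TYPES": ["organization", "person", "geo", "event", "category"]}
    global_config = {"addon_params": {"entity_types": ["organization", "person"]}}

    entity_types = global_config["addon_params"].get(
        "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
    )
    print(",".join(entity_types))  # organization,person
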
@@ -412,15 +415,17 @@ async def extract_entities(
     ):
         all_relationships_data.append(await result)
 
-    if not len(all_entities_data):
-        logger.warning("Didn't extract any entities, maybe your LLM is not working")
-        return None
-    if not len(all_relationships_data):
+    if not len(all_entities_data) and not len(all_relationships_data):
         logger.warning(
-            "Didn't extract any relationships, maybe your LLM is not working"
+            "Didn't extract any entities and relationships, maybe your LLM is not working"
         )
         return None
 
+    if not len(all_entities_data):
+        logger.warning("Didn't extract any entities")
+    if not len(all_relationships_data):
+        logger.warning("Didn't extract any relationships")
+
     if entity_vdb is not None:
         data_for_vdb = {
             compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
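
Note (not part of the diff): the rewritten guard aborts only when both extractions come back empty; a single empty list now logs a warning and lets processing continue. An illustrative reduction of the new control flow:

    def should_abort(entities: list, relationships: list) -> bool:
        # Abort only when *both* are empty, mirroring the combined check above.
        return not entities and not relationships

    assert should_abort([], []) is True
    assert should_abort(["ent"], []) is False  # previously an empty relationships list aborted too
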
@@ -630,6 +635,13 @@ async def _build_query_context(
             text_chunks_db,
             query_param,
         )
+        if (
+            hl_entities_context == ""
+            and hl_relations_context == ""
+            and hl_text_units_context == ""
+        ):
+            logger.warn("No high level context found. Switching to local mode.")
+            query_param.mode = "local"
     if query_param.mode == "hybrid":
         entities_context, relations_context, text_units_context = combine_contexts(
             [hl_entities_context, ll_entities_context],
@@ -865,7 +877,7 @@ async def _get_edge_data(
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
 
     if not len(results):
-        return None
+        return "", "", ""
 
     edge_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
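
Note (not part of the diff): returning ("", "", "") instead of None matters because the caller unpacks three context strings. A sketch under that assumed caller shape:

    def get_edge_data_stub(results):
        if not len(results):
            return "", "", ""  # a bare None would raise TypeError on unpacking below
        return "entities", "relations", "chunks"

    entities_context, relations_context, text_units_context = get_edge_data_stub([])
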
@@ -8,7 +8,7 @@ PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
 PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
 PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
 
-PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
+PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
 
 PROMPTS["entity_extraction"] = """-Goal-
 Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
@@ -268,14 +268,19 @@ PROMPTS[
 Question 1: {original_prompt}
 Question 2: {cached_prompt}
 
-Please evaluate:
+Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
 1. Whether these two questions are semantically similar
 2. Whether the answer to Question 2 can be used to answer Question 1
-
-Please provide a similarity score between 0 and 1, where:
-0: Completely unrelated or answer cannot be reused
+Similarity score criteria:
+0: Completely unrelated or answer cannot be reused, including but not limited to:
+   - The questions have different topics
+   - The locations mentioned in the questions are different
+   - The times mentioned in the questions are different
+   - The specific individuals mentioned in the questions are different
+   - The specific events mentioned in the questions are different
+   - The background information in the questions is different
+   - The key conditions in the questions are different
 1: Identical and answer can be directly reused
 0.5: Partially related and answer needs modification to be used
-
 Return only a number between 0-1, without any additional content.
 """
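
Note (not part of the diff): the rewritten prompt asks the model for a single number in [0, 1]. A hedged sketch of how such a score might gate cache reuse; the threshold and parsing here are assumptions, not part of this commit:

    raw_reply = "0.82"  # the model is instructed to return only a number
    try:
        score = max(0.0, min(1.0, float(raw_reply)))
    except ValueError:
        score = 0.0  # unparseable reply: treat as unrelated
    use_cached_answer = score >= 0.8  # hypothetical threshold
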
@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
             embeddings = await f
             embeddings_list.append(embeddings)
         embeddings = np.concatenate(embeddings_list)
-        for i, d in enumerate(list_data):
-            d["__vector__"] = embeddings[i]
-        results = self._client.upsert(datas=list_data)
-        return results
+        if len(embeddings) == len(list_data):
+            for i, d in enumerate(list_data):
+                d["__vector__"] = embeddings[i]
+            results = self._client.upsert(datas=list_data)
+            return results
+        else:
+            # sometimes the embedding is not returned correctly. just log it.
+            logger.error(
+                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
+            )
 
     async def query(self, query: str, top_k=5):
         embedding = await self.embedding_func([query])
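
Note (not part of the diff): the guard protects against an embedding batch coming back with a different row count than the input, which previously mis-assigned vectors silently. A standalone numpy illustration of the check:

    import numpy as np

    embeddings_list = [np.ones((2, 4)), np.ones((1, 4))]  # one batch came back short
    list_data = [{"id": i} for i in range(4)]

    embeddings = np.concatenate(embeddings_list)
    if len(embeddings) == len(list_data):
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
    else:
        print(f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}")
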
@@ -17,6 +17,17 @@ import tiktoken
 
 from lightrag.prompt import PROMPTS
 
+
+class UnlimitedSemaphore:
+    """A context manager that allows unlimited access."""
+
+    async def __aenter__(self):
+        pass
+
+    async def __aexit__(self, exc_type, exc, tb):
+        pass
+
+
 ENCODER = None
 
 logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
+    concurrent_limit: int = 16
 
+    def __post_init__(self):
+        if self.concurrent_limit != 0:
+            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
+        else:
+            self._semaphore = UnlimitedSemaphore()
+
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        return await self.func(*args, **kwargs)
+        async with self._semaphore:
+            return await self.func(*args, **kwargs)
 
 
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
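
Note (not part of the diff): with the default concurrent_limit=16, at most 16 embedding calls run at once; concurrent_limit=0 swaps in the UnlimitedSemaphore no-op. A minimal runnable sketch of the same gating pattern with a stand-in embedder:

    import asyncio

    async def fake_embed(texts):
        await asyncio.sleep(0.01)  # stand-in for a real network call
        return [[0.0] * 8 for _ in texts]

    async def main():
        sem = asyncio.Semaphore(16)  # concurrent_limit=16

        async def gated(texts):
            async with sem:  # at most 16 requests in flight at a time
                return await fake_embed(texts)

        out = await asyncio.gather(*[gated(["hello"]) for _ in range(64)])
        print(len(out))  # 64

    asyncio.run(main())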