chore: add pre-commit hooks and ruff formatting

Sanketh Kumar
2024-10-19 09:43:17 +05:30
parent b854ab4737
commit 744dad339d
26 changed files with 635 additions and 393 deletions
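pre-commit runs the configured hooks against staged files on every `git commit`, and the bulk of this diff is the output of Ruff's formatter being applied across the package. As a rough sketch of the kind of configuration this commit introduces (the actual `.pre-commit-config.yaml` is not shown in this view), the standard hooks from the astral-sh/ruff-pre-commit repository look like:

# .pre-commit-config.yaml -- illustrative sketch, not the exact file added by this commit
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9          # assumption: pin to whichever release tag the project uses
    hooks:
      - id: ruff         # linter; add "args: [--fix]" to apply autofixes
      - id: ruff-format  # formatter responsible for the quote/wrapping changes below

Contributors typically enable the hooks once with `pre-commit install` and can reformat the whole tree with `pre-commit run --all-files`.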

View File

@@ -1,4 +1,4 @@
from .lightrag import LightRAG, QueryParam
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.6"
__author__ = "Zirui Guo"

View File

@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(
T = TypeVar("T")
@dataclass
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
top_k: int = 60
max_token_for_text_unit: int = 4000
max_token_for_global_context: int = 4000
max_token_for_local_context: int = 4000
@dataclass
@@ -36,6 +37,7 @@ class StorageNameSpace:
"""commit the storage operations after querying"""
pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
async def drop(self):
raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):
@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
raise NotImplementedError("Node embedding is not used in lightrag.")

View File

@@ -3,10 +3,12 @@ import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from typing import Type, cast
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
from .llm import (
gpt_4o_mini_complete,
openai_embedding,
)
from .operate import (
chunking_by_token_size,
extract_entities,
@@ -37,6 +39,7 @@ from .base import (
QueryParam,
)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"num_walks": 10,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
)
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
@@ -98,11 +100,11 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -133,30 +135,24 @@ class LightRAG:
self.embedding_func
)
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"}
)
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = (
self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = (
self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
)
@@ -177,7 +173,7 @@ class LightRAG:
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning(f"All docs are already in the storage")
logger.warning("All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
@@ -203,7 +199,7 @@ class LightRAG:
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning(f"All chunks are already in the storage")
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
@@ -246,7 +242,7 @@ class LightRAG:
def query(self, query: str, param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
if param.mode == "local":
response = await local_query(
@@ -290,7 +286,6 @@ class LightRAG:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def _query_done(self):
tasks = []
@@ -299,5 +294,3 @@ class LightRAG:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)

View File

@@ -1,9 +1,7 @@
import os
import copy
import json
import botocore
import aioboto3
import botocore.errorfactory
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -13,24 +11,34 @@ from tenacity import (
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
import copy
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -64,43 +72,56 @@ class BedrockError(Exception):
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[],
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message['content'] = [{'text': message['content']}]
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({'role': "user", 'content': [{'text': prompt}]})
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {
'modelId': model,
'messages': messages
}
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args['system'] = [{'text': system_prompt}]
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
'max_tokens': "maxTokens",
'top_p': "topP",
'stop_sequences': "stopSequences"
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
args['inferenceConfig'] = {}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
await hashing_kv.upsert({
args_hash: {
'return': response['output']['message']['content'][0]['text'],
'model': model
await hashing_kv.upsert(
{
args_hash: {
"return": response["output"]["message"]["content"][0]["text"],
"model": model,
}
}
})
)
return response["output"]["message"]["content"][0]["text"]
return response['output']['message']['content'][0]['text']
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
if hf_tokenizer.pad_token == None:
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
if hf_tokenizer.pad_token is None:
# print("use eos token")
hf_tokenizer.pad_token = hf_tokenizer.eos_token
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
input_prompt = ''
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]['role'] == "system":
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response_text, "model": model}}
)
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
return result
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -241,7 +286,7 @@ async def bedrock_complete(
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
model_name,
prompt,
@@ -250,10 +295,11 @@ async def hf_model_complete(
**kwargs,
)
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
@@ -262,17 +308,25 @@ async def ollama_model_complete(
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps({
'inputText': text,
# 'dimensions': embedding_dim,
'embeddingTypes': ["float"]
})
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({
'inputText': text
})
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
@@ -315,29 +378,27 @@ async def bedrock_embedding(
modelId=model,
body=body,
accept="application/json",
contentType="application/json"
contentType="application/json",
)
response_body = await response.get('body').json()
response_body = await response.get("body").json()
embed_texts.append(response_body['embedding'])
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps({
'texts': texts,
'input_type': "search_document",
'truncate': "NONE"
})
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json"
contentType="application/json",
)
response_body = json.loads(response.get('body').read())
response_body = json.loads(response.get("body").read())
embed_texts = response_body['embeddings']
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
@@ -345,12 +406,15 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
embed_text = []
for text in texts:
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
if __name__ == "__main__":
import asyncio
async def main():
result = await gpt_4o_mini_complete('How are you?')
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())

View File

@@ -25,6 +25,7 @@ from .base import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
)
return results
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
description=description,
keywords=keywords,
)
return edge_data
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if not len(all_relationships_data):
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
logger.warning(
"Didn't extract any relationships, maybe your LLM is not working"
)
return None
if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
@@ -378,6 +386,7 @@ async def extract_entities(
return knwoledge_graph_inst
async def local_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -393,19 +402,24 @@ async def local_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ', '.join(keywords)
except json.JSONDecodeError as e:
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ', '.join(keywords)
keywords = ", ".join(keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
query,
system_prompt=sys_prompt,
)
if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_local_query_context(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
```
"""
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
return all_text_units
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
)
return all_edges_data
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -624,20 +650,25 @@ async def global_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords)
except json.JSONDecodeError as e:
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords)
keywords = ", ".join(keywords)
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -665,11 +696,20 @@ async def global_query(
query,
system_prompt=sys_prompt,
)
if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_global_query_context(
keywords,
knowledge_graph_inst: BaseGraphStorage,
@@ -679,14 +719,14 @@ async def _build_global_query_context(
query_param: QueryParam,
):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):
return None
edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
)
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
@@ -765,6 +805,7 @@ async def _build_global_query_context(
```
"""
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
for e in edge_datas:
entity_names.add(e["src_id"])
entity_names.add(e["tgt_id"])
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
)
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
return node_datas
async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
}
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
all_text_units = sorted(
all_text_units, key=lambda x: x["order"]
)
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
async def hybrid_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -849,24 +889,29 @@ async def hybrid_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ', '.join(hl_keywords)
ll_keywords = ', '.join(ll_keywords)
except json.JSONDecodeError as e:
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ', '.join(hl_keywords)
ll_keywords = ', '.join(ll_keywords)
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -906,53 +951,78 @@ async def hybrid_query(
query,
system_prompt=sys_prompt,
)
if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
def combine_contexts(high_level_context, low_level_context):
# Function to extract entities, relationships, and sources from context strings
def extract_sections(context):
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
entities = entities_match.group(1) if entities_match else ''
relationships = relationships_match.group(1) if relationships_match else ''
sources = sources_match.group(1) if sources_match else ''
entities_match = re.search(
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
relationships_match = re.search(
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
sources_match = re.search(
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
entities = entities_match.group(1) if entities_match else ""
relationships = relationships_match.group(1) if relationships_match else ""
sources = sources_match.group(1) if sources_match else ""
return entities, relationships, sources
# Extract sections from both contexts
if high_level_context==None:
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities, hl_relationships, hl_sources = '','',''
if high_level_context is None:
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
hl_entities, hl_relationships, hl_sources = "", "", ""
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
if low_level_context==None:
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
ll_entities, ll_relationships, ll_sources = '','',''
if low_level_context is None:
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
ll_entities, ll_relationships, ll_sources = "", "", ""
else:
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
# Combine and deduplicate the entities
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
combined_entities = '\n'.join(combined_entities_set)
combined_entities_set = set(
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
)
combined_entities = "\n".join(combined_entities_set)
# Combine and deduplicate the relationships
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
combined_relationships = '\n'.join(combined_relationships_set)
combined_relationships_set = set(
filter(
None,
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
)
)
combined_relationships = "\n".join(combined_relationships_set)
# Combine and deduplicate the sources
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
combined_sources = '\n'.join(combined_sources_set)
combined_sources_set = set(
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
)
combined_sources = "\n".join(combined_sources_set)
# Format the combined context
return f"""
-----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
{combined_sources}
"""
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
system_prompt=sys_prompt,
)
if len(response)>len(sys_prompt):
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
return response
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response

View File

@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS[
"entity_extraction"
] = """-Goal-
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
@@ -146,9 +144,7 @@ PROMPTS[
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
PROMPTS[
"rag_response"
] = """---Role---
PROMPTS["rag_response"] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
@@ -241,9 +237,7 @@ Output:
"""
PROMPTS[
"naive_rag_response"
] = """You're a helpful assistant
PROMPTS["naive_rag_response"] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---

View File

@@ -1,16 +1,11 @@
import asyncio
import html
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Union, cast
import pickle
import hnswlib
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
import xxhash
from .utils import load_json, logger, write_json
from .base import (
@@ -19,6 +14,7 @@ from .base import (
BaseVectorStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self):
self._data = {}
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self):
self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)

View File

@@ -16,18 +16,22 @@ ENCODER = None
logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
if not logger.handlers:
logger.addHandler(file_handler)
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -36,7 +40,8 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else:
return None
def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
return final_decro
def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens)
return content
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data
def list_of_list_to_csv(data: list[list]):
return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
)
def save_data_to_file(data, file_name):
with open(file_name, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)