chore: added pre-commit-hooks and ruff formatting for commit-hooks

This commit is contained in:
Sanketh Kumar
2024-10-19 09:43:17 +05:30
parent b854ab4737
commit 744dad339d
26 changed files with 635 additions and 393 deletions

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@ __pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/

22
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,22 @@
# pre-commit configuration: generic hygiene hooks, Ruff lint/format, and an
# opt-in packaging manifest check.
repos:
  # General-purpose fixers for whitespace, file endings and requirements files.
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    hooks:
      # The linter runs BEFORE the formatter: with --fix it can rewrite code,
      # and those rewrites may need reformatting afterwards.
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/mgedmin/check-manifest
    rev: "0.49"
    hooks:
      # Restricted to the manual stage, so it only runs when requested
      # explicitly (e.g. `pre-commit run --hook-stage manual check-manifest`).
      - id: check-manifest
        stages: [manual]

View File

@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
<details>
<summary> Using Open AI-like APIs </summary>
LightRAG also support Open AI-like chat/embeddings APIs:
LightRAG also supports Open AI-like chat/embeddings APIs:
```python
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
@@ -187,10 +187,10 @@ with open("./newText.txt") as f:
```
## Evaluation
### Dataset
The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
### Generate Query
LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
<details>
<summary> Prompt </summary>
@@ -384,7 +384,7 @@ def insert_text(rag, file_path):
### Step-2 Generate Queries
We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
<details>
<summary> Code </summary>

View File

@@ -1,4 +1,3 @@
import os
import re
import json
import jsonlines
@@ -9,22 +8,22 @@ from openai import OpenAI
def batch_eval(query_file, result1_file, result2_file, output_file_path):
client = OpenAI()
with open(query_file, 'r') as f:
with open(query_file, "r") as f:
data = f.read()
queries = re.findall(r'- Question \d+: (.+)', data)
queries = re.findall(r"- Question \d+: (.+)", data)
with open(result1_file, 'r') as f:
with open(result1_file, "r") as f:
answers1 = json.load(f)
answers1 = [i['result'] for i in answers1]
answers1 = [i["result"] for i in answers1]
with open(result2_file, 'r') as f:
with open(result2_file, "r") as f:
answers2 = json.load(f)
answers2 = [i['result'] for i in answers2]
answers2 = [i["result"] for i in answers2]
requests = []
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
sys_prompt = f"""
sys_prompt = """
---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
"""
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
}}
"""
request_data = {
"custom_id": f"request-{i+1}",
"method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
"model": "gpt-4o-mini",
"messages": [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
],
}
},
}
requests.append(request_data)
with jsonlines.open(output_file_path, mode='w') as writer:
with jsonlines.open(output_file_path, mode="w") as writer:
for request in requests:
writer.write(request)
print(f"Batch API requests written to {output_file_path}")
batch_input_file = client.files.create(
file=open(output_file_path, "rb"),
purpose="batch"
file=open(output_file_path, "rb"), purpose="batch"
)
batch_input_file_id = batch_input_file.id
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
input_file_id=batch_input_file_id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"description": "nightly eval job"
}
metadata={"description": "nightly eval job"},
)
print(f'Batch {batch.id} has been created.')
print(f"Batch {batch.id} has been created.")
if __name__ == "__main__":
batch_eval()

View File

@@ -1,9 +1,8 @@
import os
from openai import OpenAI
# os.environ["OPENAI_API_KEY"] = ""
def openai_complete_if_cache(
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -47,9 +46,9 @@ if __name__ == "__main__":
...
"""
result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
file_path = f"./queries.txt"
file_path = "./queries.txt"
with open(file_path, "w") as file:
file.write(result)

View File

@@ -20,13 +20,11 @@ rag = LightRAG(
llm_model_func=bedrock_complete,
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=bedrock_embedding
)
embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
),
)
with open("./book.txt", 'r', encoding='utf-8') as f:
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
for mode in ["naive", "local", "global", "hybrid"]:
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
print(f"| {mode.capitalize()} |")
print("+-" + "-" * len(mode) + "-+\n")
print(
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode=mode)
)
rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
)

View File

@@ -1,5 +1,4 @@
import os
import sys
from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model_complete, hf_embedding
@@ -14,15 +13,19 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=hf_model_complete,
llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
)
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
embed_model=AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
),
),
)
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -12,14 +12,11 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
llm_model_name='your_model_name',
llm_model_name="your_model_name",
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
),
)
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -10,6 +10,7 @@ WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar"
base_url="https://api.upstage.ai/v1/solar",
)
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
result = await embedding_func(["How are you?"])
print("embedding_func: ", result)
asyncio.run(test_funcs())
@@ -46,10 +50,8 @@ rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096,
max_token_size=8192,
func=embedding_func
)
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,9 +1,7 @@
import os
import sys
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from transformers import AutoModel,AutoTokenizer
from lightrag.llm import gpt_4o_mini_complete
WORKING_DIR = "./dickens"
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete
llm_model_func=gpt_4o_mini_complete,
# llm_model_func=gpt_4o_complete
)
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,4 +1,4 @@
from .lightrag import LightRAG, QueryParam
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.6"
__author__ = "Zirui Guo"

View File

@@ -12,6 +12,7 @@ TextChunkSchema = TypedDict(
T = TypeVar("T")
@dataclass
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
@@ -36,6 +37,7 @@ class StorageNameSpace:
"""commit the storage operations after querying"""
pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:

View File

@@ -3,10 +3,12 @@ import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from typing import Type, cast
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
from .llm import (
gpt_4o_mini_complete,
openai_embedding,
)
from .operate import (
chunking_by_token_size,
extract_entities,
@@ -37,6 +39,7 @@ from .base import (
QueryParam,
)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"num_walks": 10,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
@@ -83,7 +85,7 @@ class LightRAG:
# LLM
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
@@ -133,29 +135,23 @@ class LightRAG:
self.embedding_func
)
self.entities_vdb = (
self.vector_db_storage_cls(
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"}
meta_fields={"entity_name"},
)
)
self.relationships_vdb = (
self.vector_db_storage_cls(
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}
meta_fields={"src_id", "tgt_id"},
)
)
self.chunks_vdb = (
self.vector_db_storage_cls(
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
@@ -177,7 +173,7 @@ class LightRAG:
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning(f"All docs are already in the storage")
logger.warning("All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
@@ -203,7 +199,7 @@ class LightRAG:
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning(f"All chunks are already in the storage")
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
@@ -291,7 +287,6 @@ class LightRAG:
await self._query_done()
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
@@ -299,5 +294,3 @@ class LightRAG:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)

View File

@@ -1,9 +1,7 @@
import os
import copy
import json
import botocore
import aioboto3
import botocore.errorfactory
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -13,24 +11,34 @@ from tenacity import (
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
import copy
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -64,43 +72,56 @@ class BedrockError(Exception):
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[],
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message['content'] = [{'text': message['content']}]
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({'role': "user", 'content': [{'text': prompt}]})
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {
'modelId': model,
'messages': messages
}
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args['system'] = [{'text': system_prompt}]
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
'max_tokens': "maxTokens",
'top_p': "topP",
'stop_sequences': "stopSequences"
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
args['inferenceConfig'] = {}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
await hashing_kv.upsert({
await hashing_kv.upsert(
{
args_hash: {
'return': response['output']['message']['content'][0]['text'],
'model': model
"return": response["output"]["message"]["content"][0]["text"],
"model": model,
}
})
}
)
return response["output"]["message"]["content"][0]["text"]
return response['output']['message']['content'][0]['text']
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
if hf_tokenizer.pad_token == None:
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
if hf_tokenizer.pad_token is None:
# print("use eos token")
hf_tokenizer.pad_token = hf_tokenizer.eos_token
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
input_prompt = ''
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]['role'] == "system":
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response_text, "model": model}}
)
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
return result
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -241,7 +286,7 @@ async def bedrock_complete(
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
model_name,
prompt,
@@ -250,10 +295,11 @@ async def hf_model_complete(
**kwargs,
)
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
@@ -262,17 +308,25 @@ async def ollama_model_complete(
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps({
'inputText': text,
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
'embeddingTypes': ["float"]
})
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({
'inputText': text
})
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
@@ -315,29 +378,27 @@ async def bedrock_embedding(
modelId=model,
body=body,
accept="application/json",
contentType="application/json"
contentType="application/json",
)
response_body = await response.get('body').json()
response_body = await response.get("body").json()
embed_texts.append(response_body['embedding'])
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps({
'texts': texts,
'input_type': "search_document",
'truncate': "NONE"
})
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json"
contentType="application/json",
)
response_body = json.loads(response.get('body').read())
response_body = json.loads(response.get("body").read())
embed_texts = response_body['embeddings']
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
@@ -345,12 +406,15 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
embed_text = []
for text in texts:
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
if __name__ == "__main__":
import asyncio
async def main():
result = await gpt_4o_mini_complete('How are you?')
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())

View File

@@ -25,6 +25,7 @@ from .base import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
)
return results
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
@@ -232,6 +234,7 @@ async def _merge_edges_then_upsert(
return edge_data
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if not len(all_relationships_data):
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
logger.warning(
"Didn't extract any relationships, maybe your LLM is not working"
)
return None
if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
@@ -378,6 +386,7 @@ async def extract_entities(
return knwoledge_graph_inst
async def local_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -397,15 +406,20 @@ async def local_query(
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ', '.join(keywords)
except json.JSONDecodeError as e:
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ', '.join(keywords)
keywords = ", ".join(keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -431,10 +445,19 @@ async def local_query(
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_local_query_context(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
```
"""
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
return all_text_units
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
)
return all_edges_data
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -628,15 +654,20 @@ async def global_query(
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords)
except json.JSONDecodeError as e:
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords)
keywords = ", ".join(keywords)
except json.JSONDecodeError as e:
# Handle parsing error
@@ -666,10 +697,19 @@ async def global_query(
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_global_query_context(
keywords,
knowledge_graph_inst: BaseGraphStorage,
@@ -765,6 +805,7 @@ async def _build_global_query_context(
```
"""
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
return node_datas
async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas
@@ -822,9 +863,7 @@ async def _find_related_text_unit_from_relationships(
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
all_text_units = sorted(
all_text_units, key=lambda x: x["order"]
)
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
async def hybrid_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -855,18 +895,23 @@ async def hybrid_query(
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ', '.join(hl_keywords)
ll_keywords = ', '.join(ll_keywords)
except json.JSONDecodeError as e:
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
except json.JSONDecodeError:
try:
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ', '.join(hl_keywords)
ll_keywords = ', '.join(ll_keywords)
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -907,51 +952,76 @@ async def hybrid_query(
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
def combine_contexts(high_level_context, low_level_context):
# Function to extract entities, relationships, and sources from context strings
def extract_sections(context):
    """Pull the CSV payloads of the three labelled sections out of ``context``.

    Each section is expected as a ``-----<Title>-----`` banner followed by a
    fenced ``csv`` block. A section that is not found yields an empty string.

    Returns:
        A ``(entities, relationships, sources)`` tuple of strings.
    """
    section_patterns = (
        r"-----Entities-----\s*```csv\s*(.*?)\s*```",
        r"-----Relationships-----\s*```csv\s*(.*?)\s*```",
        r"-----Sources-----\s*```csv\s*(.*?)\s*```",
    )

    # re.DOTALL lets the lazy group span the multiple lines of a CSV block.
    extracted = []
    for pattern in section_patterns:
        found = re.search(pattern, context, re.DOTALL)
        extracted.append(found.group(1) if found else "")

    entities, relationships, sources = extracted
    return entities, relationships, sources
# Extract sections from both contexts
if high_level_context==None:
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities, hl_relationships, hl_sources = '','',''
if high_level_context is None:
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
hl_entities, hl_relationships, hl_sources = "", "", ""
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
if low_level_context==None:
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
ll_entities, ll_relationships, ll_sources = '','',''
if low_level_context is None:
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
ll_entities, ll_relationships, ll_sources = "", "", ""
else:
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
# Combine and deduplicate the entities
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
combined_entities = '\n'.join(combined_entities_set)
combined_entities_set = set(
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
)
combined_entities = "\n".join(combined_entities_set)
# Combine and deduplicate the relationships
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
combined_relationships = '\n'.join(combined_relationships_set)
combined_relationships_set = set(
filter(
None,
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
)
)
combined_relationships = "\n".join(combined_relationships_set)
# Combine and deduplicate the sources
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
combined_sources = '\n'.join(combined_sources_set)
combined_sources_set = set(
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
)
combined_sources = "\n".join(combined_sources_set)
# Format the combined context
return f"""
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
{combined_sources}
"""
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
@@ -997,7 +1068,15 @@ async def naive_query(
)
if len(response) > len(sys_prompt):
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response

View File

@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS[
"entity_extraction"
] = """-Goal-
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
@@ -146,9 +144,7 @@ PROMPTS[
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
PROMPTS[
"rag_response"
] = """---Role---
PROMPTS["rag_response"] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
@@ -241,9 +237,7 @@ Output:
"""
PROMPTS[
"naive_rag_response"
] = """You're a helpful assistant
PROMPTS["naive_rag_response"] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---

View File

@@ -1,16 +1,11 @@
import asyncio
import html
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Union, cast
import pickle
import hnswlib
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
import xxhash
from .utils import load_json, logger, write_json
from .base import (
@@ -19,6 +14,7 @@ from .base import (
BaseVectorStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self):
self._data = {}
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self):
self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)

View File

@@ -16,18 +16,22 @@ ENCODER = None
logger = logging.getLogger("lightrag")
def set_logger(log_file: str):
    """Send the package logger's output to ``log_file`` at DEBUG level.

    The logger level is always set to DEBUG. A file handler is installed only
    when the logger has no handlers yet, so repeated calls do not stack
    handlers.

    Args:
        log_file: Path of the file the handler writes to.
    """
    logger.setLevel(logging.DEBUG)

    # Check before building the handler: constructing a FileHandler opens the
    # file, so creating one that is never attached would leave it open.
    if logger.handlers:
        return

    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -37,6 +41,7 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else:
return None
def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
def compute_args_hash(*args):
    """Return the hex MD5 digest of ``str(args)`` for the given positional arguments."""
    serialized = str(args)
    digest = md5(serialized.encode())
    return digest.hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
    """Return ``prefix`` followed by the hex MD5 digest of ``content``.

    ``content`` must support ``.encode()`` (i.e. be a string).
    """
    content_digest = md5(content.encode()).hexdigest()
    return f"{prefix}{content_digest}"
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
return final_decro
def load_json(file_name):
    """Parse and return the JSON stored at ``file_name``.

    Returns ``None`` when no file exists at that path.
    """
    if os.path.exists(file_name):
        with open(file_name, encoding="utf-8") as handle:
            return json.load(handle)
    return None
def write_json(json_obj, file_name):
    """Serialize ``json_obj`` to ``file_name`` as UTF-8 JSON.

    Output is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped. An existing file is overwritten.
    """
    with open(file_name, "w", encoding="utf-8") as out_file:
        json.dump(json_obj, out_file, ensure_ascii=False, indent=2)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens)
return content
def pack_user_ass_to_openai_messages(*args: str):
    """Turn alternating user/assistant strings into a list of role-tagged messages.

    Arguments at even positions get the role ``"user"``, those at odd
    positions ``"assistant"``. Each message is a dict with ``"role"`` and
    ``"content"`` keys.
    """
    messages = []
    for position, content in enumerate(args):
        role = "assistant" if position % 2 else "user"
        messages.append({"role": role, "content": content})
    return messages
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
# Refer to the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value):
    """Report whether ``value`` is a plain decimal number string.

    Accepts an optional sign, optional integer digits, an optional dot and at
    least one trailing digit (e.g. ``"3"``, ``"-1.5"``, ``"+.5"``).
    """
    float_match = re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)
    return float_match is not None
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data
def list_of_list_to_csv(data: list[list]):
    """Render rows of values as text, one row per line.

    Every cell is converted with ``str``; cells in a row are joined with a
    comma followed by a tab, and rows are joined with newlines.
    """
    rendered_rows = []
    for row in data:
        rendered_rows.append(",\t".join(map(str, row)))
    return "\n".join(rendered_rows)
def save_data_to_file(data, file_name):
    """Dump ``data`` to ``file_name`` as UTF-8 JSON with four-space indentation.

    Non-ASCII characters are kept verbatim. An existing file is overwritten.
    """
    with open(file_name, "w", encoding="utf-8") as out_file:
        json.dump(data, out_file, indent=4, ensure_ascii=False)

View File

@@ -3,11 +3,11 @@ import json
import glob
import argparse
def extract_unique_contexts(input_directory, output_directory):
def extract_unique_contexts(input_directory, output_directory):
os.makedirs(output_directory, exist_ok=True)
jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
print(f"Found {len(jsonl_files)} JSONL files.")
for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
print(f"Processing file: {filename}")
try:
with open(file_path, 'r', encoding='utf-8') as infile:
with open(file_path, "r", encoding="utf-8") as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
context = json_obj.get('context')
context = json_obj.get("context")
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
print(
f"JSON decoding error in file {filename} at line {line_number}: {e}"
)
except FileNotFoundError:
print(f"File not found: {filename}")
continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
continue
unique_contexts_list = list(unique_contexts_dict.keys())
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
print(
f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
)
try:
with open(output_path, 'w', encoding='utf-8') as outfile:
with open(output_path, "w", encoding="utf-8") as outfile:
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
print(f"Unique `context` entries have been saved to: {output_filename}")
except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
parser.add_argument(
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
)
args = parser.parse_args()

View File

@@ -4,8 +4,9 @@ import time
from lightrag import LightRAG
def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
retries = 0
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
cls = "agriculture"
# Must be an f-string: without the prefix the path is the literal "../{cls}"
# instead of a per-dataset working directory.
WORKING_DIR = f"../{cls}"

View File

@@ -7,6 +7,7 @@ from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -19,20 +20,24 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` via ``openai_embedding`` against the Upstage Solar endpoint.

    Uses the ``solar-embedding-1-large-query`` model; the API key is read from
    the ``UPSTAGE_API_KEY`` environment variable.
    """
    upstage_key = os.getenv("UPSTAGE_API_KEY")
    embeddings = await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=upstage_key,
        base_url="https://api.upstage.ai/v1/solar",
    )
    return embeddings
## /For Upstage API
def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
retries = 0
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
# Dataset to index; it selects both the working directory and the input file.
cls = "mix"
WORKING_DIR = f"../{cls}"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# Wrap the embedding coroutine with the dimension/token limits LightRAG is given.
upstage_embedding = EmbeddingFunc(
    embedding_dim=4096,
    max_token_size=8192,
    func=embedding_func,
)
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=upstage_embedding,
)

contexts_path = f"../datasets/unique_contexts/{cls}_unique_contexts.json"
insert_text(rag, contexts_path)

View File

@@ -1,8 +1,8 @@
import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer
def openai_complete_if_cache(
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -19,7 +19,9 @@ def openai_complete_if_cache(
)
return response.choices[0].message.content
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
def get_summary(context, tot_tokens=2000):
tokens = tokenizer.tokenize(context)
@@ -34,9 +36,9 @@ def get_summary(context, tot_tokens=2000):
return summary
clses = ['agriculture']
clses = ["agriculture"]
for cls in clses:
with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
unique_contexts = json.load(f)
summaries = [get_summary(context) for context in unique_contexts]
@@ -67,7 +69,7 @@ for cls in clses:
...
"""
result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
file_path = f"../datasets/questions/{cls}_questions.txt"
with open(file_path, "w") as file:

View File

@@ -4,16 +4,18 @@ import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
def extract_queries(file_path):
    """Read ``file_path`` and return the question texts it lists.

    All ``**`` markers are removed first; then every ``- Question <n>: <text>``
    occurrence contributes its ``<text>`` (to the end of the line) to the
    returned list.
    """
    with open(file_path, "r") as source:
        raw_text = source.read()

    question_pattern = re.compile(r"- Question \d+: (.+)")
    return question_pattern.findall(raw_text.replace("**", ""))
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):
loop = always_get_an_event_loop()
with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
with open(output_file, "a", encoding="utf-8") as result_file, open(
error_file, "a", encoding="utf-8"
) as err_file:
result_file.write("[\n")
first_entry = True
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
result, error = loop.run_until_complete(
process_query(query_text, rag_instance, query_param)
)
if result:
if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
result_file.write("\n]")
if __name__ == "__main__":
cls = "agriculture"
mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
query_param = QueryParam(mode=mode)
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
run_queries_and_save_to_json(
queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
)

View File

@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` via ``openai_embedding`` against the Upstage Solar endpoint.

    Uses the ``solar-embedding-1-large-query`` model; the API key is read from
    the ``UPSTAGE_API_KEY`` environment variable.
    """
    upstage_key = os.getenv("UPSTAGE_API_KEY")
    embeddings = await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=upstage_key,
        base_url="https://api.upstage.ai/v1/solar",
    )
    return embeddings
## /For Upstage API
def extract_queries(file_path):
    """Read ``file_path`` and return the question texts it lists.

    All ``**`` markers are removed first; then every ``- Question <n>: <text>``
    occurrence contributes its ``<text>`` (to the end of the line) to the
    returned list.
    """
    with open(file_path, "r") as source:
        raw_text = source.read()

    question_pattern = re.compile(r"- Question \d+: (.+)")
    return question_pattern.findall(raw_text.replace("**", ""))
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):
loop = always_get_an_event_loop()
with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
with open(output_file, "a", encoding="utf-8") as result_file, open(
error_file, "a", encoding="utf-8"
) as err_file:
result_file.write("[\n")
first_entry = True
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
result, error = loop.run_until_complete(
process_query(query_text, rag_instance, query_param)
)
if result:
if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
result_file.write("\n]")
if __name__ == "__main__":
    # Dataset and retrieval mode for this run.
    cls = "mix"
    mode = "hybrid"
    WORKING_DIR = f"../{cls}"

    # Construct LightRAG exactly once, with the Upstage-backed LLM and
    # embedding functions. Building a default-configured instance on the same
    # working directory first would be discarded immediately and only add
    # unwanted initialization side effects.
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=4096, max_token_size=8192, func=embedding_func
        ),
    )
    query_param = QueryParam(mode=mode)

    # Questions are read from, and results/errors written to, the same folder.
    base_dir = "../datasets/questions"
    queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
    run_queries_and_save_to_json(
        queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
    )

View File

@@ -1,13 +1,13 @@
aioboto3
openai
tiktoken
networkx
graspologic
nano-vectordb
hnswlib
xxhash
tenacity
transformers
torch
ollama
accelerate
aioboto3
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
tenacity
tiktoken
torch
transformers
xxhash