chore: added pre-commit-hooks and ruff formatting for commit-hooks

Sanketh Kumar
2024-10-19 09:43:17 +05:30
parent 99bd644bf7
commit 32464fab4e
26 changed files with 635 additions and 393 deletions
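The .pre-commit-config.yaml added by this commit is not shown in this excerpt. A minimal sketch of a config that would produce the changes below (the repo URLs and hook ids are the standard ones; the rev pins are assumptions):

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0  # assumed version pin
    hooks:
      - id: trailing-whitespace  # strips trailing spaces (likely the identical -/+ line pairs below)
      - id: end-of-file-fixer    # adds the missing final newline seen in several files
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9  # assumed version pin
    hooks:
      - id: ruff-format          # double quotes, trailing commas, line wrapping

After pre-commit install, running pre-commit run --all-files applies the hooks to the whole tree, which accounts for the mechanical quote, trailing-comma, and wrapping churn in the diffs that follow.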

View File

@@ -3,11 +3,11 @@ import json
 import glob
 import argparse
-def extract_unique_contexts(input_directory, output_directory):
+def extract_unique_contexts(input_directory, output_directory):
     os.makedirs(output_directory, exist_ok=True)
-    jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
+    jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
     print(f"Found {len(jsonl_files)} JSONL files.")
     for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
print(f"Processing file: {filename}")
try:
with open(file_path, 'r', encoding='utf-8') as infile:
with open(file_path, "r", encoding="utf-8") as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
context = json_obj.get('context')
context = json_obj.get("context")
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
print(
f"JSON decoding error in file {filename} at line {line_number}: {e}"
)
except FileNotFoundError:
print(f"File not found: {filename}")
continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
             continue
         unique_contexts_list = list(unique_contexts_dict.keys())
-        print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
+        print(
+            f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
+        )
         try:
-            with open(output_path, 'w', encoding='utf-8') as outfile:
+            with open(output_path, "w", encoding="utf-8") as outfile:
                 json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
             print(f"Unique `context` entries have been saved to: {output_filename}")
         except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
parser.add_argument(
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
)
args = parser.parse_args()

View File

@@ -4,10 +4,11 @@ import time
 from lightrag import LightRAG
 def insert_text(rag, file_path):
-    with open(file_path, mode='r') as f:
+    with open(file_path, mode="r") as f:
         unique_contexts = json.load(f)
     retries = 0
     max_retries = 3
     while retries < max_retries:
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
     if retries == max_retries:
         print("Insertion failed after exceeding the maximum number of retries")
 cls = "agriculture"
 WORKING_DIR = f"../{cls}"
@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(working_dir=WORKING_DIR)
-insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
+insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")

View File

@@ -7,6 +7,7 @@ from lightrag import LightRAG
 from lightrag.utils import EmbeddingFunc
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in the lightrag directory
 async def llm_model_func(
@@ -19,22 +20,26 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
 ## /For Upstage API
 def insert_text(rag, file_path):
-    with open(file_path, mode='r') as f:
+    with open(file_path, mode="r") as f:
         unique_contexts = json.load(f)
     retries = 0
     max_retries = 3
     while retries < max_retries:
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
     if retries == max_retries:
         print("Insertion failed after exceeding the maximum number of retries")
 cls = "mix"
 WORKING_DIR = f"../{cls}"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
-rag = LightRAG(working_dir=WORKING_DIR,
-               llm_model_func=llm_model_func,
-               embedding_func=EmbeddingFunc(
-                   embedding_dim=4096,
-                   max_token_size=8192,
-                   func=embedding_func
-               )
-)
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=4096, max_token_size=8192, func=embedding_func
+    ),
+)
 insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")

View File

@@ -1,8 +1,8 @@
 import os
 import json
 from openai import OpenAI
 from transformers import GPT2Tokenizer
 def openai_complete_if_cache(
     model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -19,24 +19,26 @@ def openai_complete_if_cache(
     )
     return response.choices[0].message.content
-tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 def get_summary(context, tot_tokens=2000):
     tokens = tokenizer.tokenize(context)
     half_tokens = tot_tokens // 2
-    start_tokens = tokens[1000:1000 + half_tokens]
-    end_tokens = tokens[-(1000 + half_tokens):1000]
+    start_tokens = tokens[1000 : 1000 + half_tokens]
+    end_tokens = tokens[-(1000 + half_tokens) : 1000]
     summary_tokens = start_tokens + end_tokens
     summary = tokenizer.convert_tokens_to_string(summary_tokens)
     return summary
-clses = ['agriculture']
+clses = ["agriculture"]
 for cls in clses:
-    with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
+    with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
         unique_contexts = json.load(f)
     summaries = [get_summary(context) for context in unique_contexts]
@@ -67,10 +69,10 @@ for cls in clses:
     ...
     """
-    result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
+    result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
     file_path = f"../datasets/questions/{cls}_questions.txt"
     with open(file_path, "w") as file:
         file.write(result)
-    print(f"{cls}_questions written to {file_path}")
+    print(f"{cls}_questions written to {file_path}")

View File

@@ -4,16 +4,18 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from tqdm import tqdm
-def extract_queries(file_path):
-    with open(file_path, 'r') as f:
-        data = f.read()
-    data = data.replace('**', '')
-    queries = re.findall(r'- Question \d+: (.+)', data)
+def extract_queries(file_path):
+    with open(file_path, "r") as f:
+        data = f.read()
+    data = data.replace("**", "")
+    queries = re.findall(r"- Question \d+: (.+)", data)
     return queries
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
             if result:
                 if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
         result_file.write("\n]")
 if __name__ == "__main__":
     cls = "agriculture"
     mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
     query_param = QueryParam(mode=mode)
     queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+    )

View File

@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in the lightrag directory
 async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
 ## /For Upstage API
-def extract_queries(file_path):
-    with open(file_path, 'r') as f:
-        data = f.read()
-    data = data.replace('**', '')
-    queries = re.findall(r'- Question \d+: (.+)', data)
+def extract_queries(file_path):
+    with open(file_path, "r") as f:
+        data = f.read()
+    data = data.replace("**", "")
+    queries = re.findall(r"- Question \d+: (.+)", data)
     return queries
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
             if result:
                 if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
         result_file.write("\n]")
 if __name__ == "__main__":
     cls = "mix"
     mode = "hybrid"
     WORKING_DIR = f"../{cls}"
     rag = LightRAG(working_dir=WORKING_DIR)
-    rag = LightRAG(working_dir=WORKING_DIR,
-                   llm_model_func=llm_model_func,
-                   embedding_func=EmbeddingFunc(
-                       embedding_dim=4096,
-                       max_token_size=8192,
-                       func=embedding_func
-                   )
-    )
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=4096, max_token_size=8192, func=embedding_func
+        ),
+    )
     query_param = QueryParam(mode=mode)
-    base_dir='../datasets/questions'
+    base_dir = "../datasets/questions"
     queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+    )