update Step_3.py and openai compatible script
This commit is contained in:
66
reproduce/Step_1_openai_compatible.py
Normal file
66
reproduce/Step_1_openai_compatible.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lightrag import LightRAG
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
|
|
||||||
|
## For Upstage API
|
||||||
|
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
"solar-mini",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await openai_embedding(
|
||||||
|
texts,
|
||||||
|
model="solar-embedding-1-large-query",
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar"
|
||||||
|
)
|
||||||
|
## /For Upstage API
|
||||||
|
|
||||||
|
def insert_text(rag, file_path):
|
||||||
|
with open(file_path, mode='r') as f:
|
||||||
|
unique_contexts = json.load(f)
|
||||||
|
|
||||||
|
retries = 0
|
||||||
|
max_retries = 3
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
rag.insert(unique_contexts)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
retries += 1
|
||||||
|
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
|
||||||
|
time.sleep(10)
|
||||||
|
if retries == max_retries:
|
||||||
|
print("Insertion failed after exceeding the maximum number of retries")
|
||||||
|
|
||||||
|
cls = "mix"
|
||||||
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
rag = LightRAG(working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=4096,
|
||||||
|
max_token_size=8192,
|
||||||
|
func=embedding_func
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
@@ -53,10 +53,10 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cls = "agriculture"
|
cls = "agriculture"
|
||||||
mode = "hybrid"
|
mode = "hybrid"
|
||||||
WORKING_DIR = "../{cls}"
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
rag = LightRAG(working_dir=WORKING_DIR)
|
rag = LightRAG(working_dir=WORKING_DIR)
|
||||||
query_param = QueryParam(mode=mode)
|
query_param = QueryParam(mode=mode)
|
||||||
|
|
||||||
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
||||||
run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json")
|
run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
|
||||||
|
99
reproduce/Step_3_openai_compatible.py
Normal file
99
reproduce/Step_3_openai_compatible.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from tqdm import tqdm
|
||||||
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
## For Upstage API
|
||||||
|
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
"solar-mini",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await openai_embedding(
|
||||||
|
texts,
|
||||||
|
model="solar-embedding-1-large-query",
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar"
|
||||||
|
)
|
||||||
|
## /For Upstage API
|
||||||
|
|
||||||
|
def extract_queries(file_path):
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
data = data.replace('**', '')
|
||||||
|
|
||||||
|
queries = re.findall(r'- Question \d+: (.+)', data)
|
||||||
|
|
||||||
|
return queries
|
||||||
|
|
||||||
|
async def process_query(query_text, rag_instance, query_param):
|
||||||
|
try:
|
||||||
|
result, context = await rag_instance.aquery(query_text, param=query_param)
|
||||||
|
return {"query": query_text, "result": result, "context": context}, None
|
||||||
|
except Exception as e:
|
||||||
|
return None, {"query": query_text, "error": str(e)}
|
||||||
|
|
||||||
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
|
||||||
|
loop = always_get_an_event_loop()
|
||||||
|
|
||||||
|
with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
|
||||||
|
result_file.write("[\n")
|
||||||
|
first_entry = True
|
||||||
|
|
||||||
|
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
||||||
|
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
|
||||||
|
|
||||||
|
if result:
|
||||||
|
if not first_entry:
|
||||||
|
result_file.write(",\n")
|
||||||
|
json.dump(result, result_file, ensure_ascii=False, indent=4)
|
||||||
|
first_entry = False
|
||||||
|
elif error:
|
||||||
|
json.dump(error, err_file, ensure_ascii=False, indent=4)
|
||||||
|
err_file.write("\n")
|
||||||
|
|
||||||
|
result_file.write("\n]")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cls = "mix"
|
||||||
|
mode = "hybrid"
|
||||||
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
|
rag = LightRAG(working_dir=WORKING_DIR)
|
||||||
|
rag = LightRAG(working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=4096,
|
||||||
|
max_token_size=8192,
|
||||||
|
func=embedding_func
|
||||||
|
)
|
||||||
|
)
|
||||||
|
query_param = QueryParam(mode=mode)
|
||||||
|
|
||||||
|
base_dir='../datasets/questions'
|
||||||
|
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
|
||||||
|
run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
|
Reference in New Issue
Block a user