chore: added pre-commit-hooks and ruff formatting for commit-hooks

.gitignore (vendored): 1 line changed
@@ -2,3 +2,4 @@ __pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/

.pre-commit-config.yaml (new file): 22 lines
@@ -0,0 +1,22 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: requirements-txt-fixer


  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    hooks:
      - id: ruff-format
      - id: ruff
        args: [--fix]


  - repo: https://github.com/mgedmin/check-manifest
    rev: "0.49"
    hooks:
      - id: check-manifest
        stages: [manual]
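The new hooks only take effect once pre-commit is set up locally; a minimal sketch of the usual workflow (standard pre-commit CLI commands, not part of this diff):

```bash
pip install pre-commit          # assumes pre-commit is not already installed
pre-commit install              # registers the hooks in .git/hooks/pre-commit
pre-commit run --all-files      # runs trailing-whitespace, end-of-file-fixer, ruff and ruff-format once over the repo
```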
@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
<details>
<summary> Using Open AI-like APIs </summary>

LightRAG also support Open AI-like chat/embeddings APIs:
LightRAG also supports Open AI-like chat/embeddings APIs:
```python
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
@@ -187,10 +187,10 @@ with open("./newText.txt") as f:
```
## Evaluation
### Dataset
The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).

### Generate Query
LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.

<details>
<summary> Prompt </summary>
@@ -384,7 +384,7 @@ def insert_text(rag, file_path):

### Step-2 Generate Queries

We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.

<details>
<summary> Code </summary>
@@ -1,4 +1,3 @@
import os
import re
import json
import jsonlines
@@ -9,22 +8,22 @@ from openai import OpenAI
def batch_eval(query_file, result1_file, result2_file, output_file_path):
    client = OpenAI()

    with open(query_file, 'r') as f:
    with open(query_file, "r") as f:
        data = f.read()

    queries = re.findall(r'- Question \d+: (.+)', data)
    queries = re.findall(r"- Question \d+: (.+)", data)

    with open(result1_file, 'r') as f:
    with open(result1_file, "r") as f:
        answers1 = json.load(f)
    answers1 = [i['result'] for i in answers1]
    answers1 = [i["result"] for i in answers1]

    with open(result2_file, 'r') as f:
    with open(result2_file, "r") as f:
        answers2 = json.load(f)
    answers2 = [i['result'] for i in answers2]
    answers2 = [i["result"] for i in answers2]

    requests = []
    for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
        sys_prompt = f"""
        sys_prompt = """
        ---Role---
        You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
        """
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
        }}
        """


        request_data = {
            "custom_id": f"request-{i+1}",
            "method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
                "model": "gpt-4o-mini",
                "messages": [
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": prompt}
                    {"role": "user", "content": prompt},
                ],
            }
            },
        }

        requests.append(request_data)

    with jsonlines.open(output_file_path, mode='w') as writer:
    with jsonlines.open(output_file_path, mode="w") as writer:
        for request in requests:
            writer.write(request)

    print(f"Batch API requests written to {output_file_path}")

    batch_input_file = client.files.create(
        file=open(output_file_path, "rb"),
        purpose="batch"
        file=open(output_file_path, "rb"), purpose="batch"
    )
    batch_input_file_id = batch_input_file.id

@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
        input_file_id=batch_input_file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "nightly eval job"
        }
        metadata={"description": "nightly eval job"},
    )

    print(f'Batch {batch.id} has been created.')
    print(f"Batch {batch.id} has been created.")


if __name__ == "__main__":
    batch_eval()
@@ -1,9 +1,8 @@
import os

from openai import OpenAI

# os.environ["OPENAI_API_KEY"] = ""


def openai_complete_if_cache(
    model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -47,9 +46,9 @@ if __name__ == "__main__":
    ...
    """

    result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
    result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)

    file_path = f"./queries.txt"
    file_path = "./queries.txt"
    with open(file_path, "w") as file:
        file.write(result)
@@ -20,13 +20,11 @@ rag = LightRAG(
    llm_model_func=bedrock_complete,
    llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=bedrock_embedding
    )
        embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
    ),
)

with open("./book.txt", 'r', encoding='utf-8') as f:
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

for mode in ["naive", "local", "global", "hybrid"]:
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
    print(f"| {mode.capitalize()} |")
    print("+-" + "-" * len(mode) + "-+\n")
    print(
        rag.query(
            "What are the top themes in this story?",
            param=QueryParam(mode=mode)
        )
        rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
    )
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import hf_model_complete, hf_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from transformers import AutoModel,AutoTokenizer
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
@@ -14,15 +13,19 @@ if not os.path.exists(WORKING_DIR):
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=hf_model_complete,
|
||||
llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
|
||||
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
max_token_size=5000,
|
||||
func=lambda texts: hf_embedding(
|
||||
texts,
|
||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||
)
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
),
|
||||
embed_model=AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
||||
|
@@ -12,14 +12,11 @@ if not os.path.exists(WORKING_DIR):
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=ollama_model_complete,
|
||||
llm_model_name='your_model_name',
|
||||
llm_model_name="your_model_name",
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=768,
|
||||
max_token_size=8192,
|
||||
func=lambda texts: ollama_embedding(
|
||||
texts,
|
||||
embed_model="nomic-embed-text"
|
||||
)
|
||||
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
||||
|
@@ -10,6 +10,7 @@ WORKING_DIR = "./dickens"
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
@@ -20,17 +21,19 @@ async def llm_model_func(
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
base_url="https://api.upstage.ai/v1/solar",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await openai_embedding(
|
||||
texts,
|
||||
model="solar-embedding-1-large-query",
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
base_url="https://api.upstage.ai/v1/solar"
|
||||
base_url="https://api.upstage.ai/v1/solar",
|
||||
)
|
||||
|
||||
|
||||
# function test
|
||||
async def test_funcs():
|
||||
result = await llm_model_func("How are you?")
|
||||
@@ -39,6 +42,7 @@ async def test_funcs():
|
||||
result = await embedding_func(["How are you?"])
|
||||
print("embedding_func: ", result)
|
||||
|
||||
|
||||
asyncio.run(test_funcs())
|
||||
|
||||
|
||||
@@ -46,10 +50,8 @@ rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=4096,
|
||||
max_token_size=8192,
|
||||
func=embedding_func
|
||||
)
|
||||
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
||||
|
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||
from transformers import AutoModel,AutoTokenizer
|
||||
from lightrag.llm import gpt_4o_mini_complete
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete
|
||||
llm_model_func=gpt_4o_mini_complete,
|
||||
# llm_model_func=gpt_4o_complete
|
||||
)
|
||||
|
||||
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
||||
|
@@ -1,4 +1,4 @@
from .lightrag import LightRAG, QueryParam
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

__version__ = "0.0.6"
__author__ = "Zirui Guo"
@@ -12,6 +12,7 @@ TextChunkSchema = TypedDict(
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryParam:
|
||||
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
||||
@@ -36,6 +37,7 @@ class StorageNameSpace:
|
||||
"""commit the storage operations after querying"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVectorStorage(StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc
|
||||
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||
async def all_keys(self) -> list[str]:
|
||||
|
@@ -3,10 +3,12 @@ import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Type, cast, Any
|
||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import Type, cast
|
||||
|
||||
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
|
||||
from .llm import (
|
||||
gpt_4o_mini_complete,
|
||||
openai_embedding,
|
||||
)
|
||||
from .operate import (
|
||||
chunking_by_token_size,
|
||||
extract_entities,
|
||||
@@ -37,6 +39,7 @@ from .base import (
|
||||
QueryParam,
|
||||
)
|
||||
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -69,7 +72,6 @@ class LightRAG:
|
||||
"dimensions": 1536,
|
||||
"num_walks": 10,
|
||||
"walk_length": 40,
|
||||
"num_walks": 10,
|
||||
"window_size": 2,
|
||||
"iterations": 3,
|
||||
"random_seed": 3,
|
||||
@@ -77,13 +79,13 @@ class LightRAG:
|
||||
)
|
||||
|
||||
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
||||
embedding_batch_num: int = 32
|
||||
embedding_func_max_async: int = 16
|
||||
|
||||
# LLM
|
||||
llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
|
||||
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_max_token_size: int = 32768
|
||||
llm_model_max_async: int = 16
|
||||
|
||||
@@ -133,28 +135,22 @@ class LightRAG:
|
||||
self.embedding_func
|
||||
)
|
||||
|
||||
self.entities_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"}
|
||||
)
|
||||
self.entities_vdb = self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"},
|
||||
)
|
||||
self.relationships_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="relationships",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"}
|
||||
)
|
||||
self.relationships_vdb = self.vector_db_storage_cls(
|
||||
namespace="relationships",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"},
|
||||
)
|
||||
self.chunks_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="chunks",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
self.chunks_vdb = self.vector_db_storage_cls(
|
||||
namespace="chunks",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
@@ -177,7 +173,7 @@ class LightRAG:
|
||||
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
||||
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
||||
if not len(new_docs):
|
||||
logger.warning(f"All docs are already in the storage")
|
||||
logger.warning("All docs are already in the storage")
|
||||
return
|
||||
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
||||
|
||||
@@ -203,7 +199,7 @@ class LightRAG:
|
||||
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
||||
}
|
||||
if not len(inserting_chunks):
|
||||
logger.warning(f"All chunks are already in the storage")
|
||||
logger.warning("All chunks are already in the storage")
|
||||
return
|
||||
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
||||
|
||||
@@ -291,7 +287,6 @@ class LightRAG:
|
||||
await self._query_done()
|
||||
return response
|
||||
|
||||
|
||||
async def _query_done(self):
|
||||
tasks = []
|
||||
for storage_inst in [self.llm_response_cache]:
|
||||
@@ -299,5 +294,3 @@ class LightRAG:
|
||||
continue
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
await asyncio.gather(*tasks)

lightrag/llm.py (221 lines changed)
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import botocore
|
||||
import aioboto3
|
||||
import botocore.errorfactory
|
||||
import numpy as np
|
||||
import ollama
|
||||
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
||||
@@ -13,24 +11,34 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
import copy
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_complete_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
@@ -64,43 +72,56 @@ class BedrockError(Exception):
|
||||
retry=retry_if_exception_type((BedrockError)),
|
||||
)
|
||||
async def bedrock_complete_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[],
|
||||
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
|
||||
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
|
||||
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
# Fix message history format
|
||||
messages = []
|
||||
for history_message in history_messages:
|
||||
message = copy.copy(history_message)
|
||||
message['content'] = [{'text': message['content']}]
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
messages.append(message)
|
||||
|
||||
# Add user prompt
|
||||
messages.append({'role': "user", 'content': [{'text': prompt}]})
|
||||
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||
|
||||
# Initialize Converse API arguments
|
||||
args = {
|
||||
'modelId': model,
|
||||
'messages': messages
|
||||
}
|
||||
args = {"modelId": model, "messages": messages}
|
||||
|
||||
# Define system prompt
|
||||
if system_prompt:
|
||||
args['system'] = [{'text': system_prompt}]
|
||||
args["system"] = [{"text": system_prompt}]
|
||||
|
||||
# Map and set up inference parameters
|
||||
inference_params_map = {
|
||||
'max_tokens': "maxTokens",
|
||||
'top_p': "topP",
|
||||
'stop_sequences': "stopSequences"
|
||||
"max_tokens": "maxTokens",
|
||||
"top_p": "topP",
|
||||
"stop_sequences": "stopSequences",
|
||||
}
|
||||
if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
|
||||
args['inferenceConfig'] = {}
|
||||
if inference_params := list(
|
||||
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
||||
):
|
||||
args["inferenceConfig"] = {}
|
||||
for param in inference_params:
|
||||
args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
|
||||
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
||||
kwargs.pop(param)
|
||||
)
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
if hashing_kv is not None:
|
||||
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
|
||||
# Call model via Converse API
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
|
||||
try:
|
||||
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||
except Exception as e:
|
||||
raise BedrockError(e)
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({
|
||||
args_hash: {
|
||||
'return': response['output']['message']['content'][0]['text'],
|
||||
'model': model
|
||||
await hashing_kv.upsert(
|
||||
{
|
||||
args_hash: {
|
||||
"return": response["output"]["message"]["content"][0]["text"],
|
||||
"model": model,
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
return response['output']['message']['content'][0]['text']
|
||||
|
||||
async def hf_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = model
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
|
||||
if hf_tokenizer.pad_token == None:
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
|
||||
if hf_tokenizer.pad_token is None:
|
||||
# print("use eos token")
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
input_prompt = ''
|
||||
input_prompt = ""
|
||||
try:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
except:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
ori_message = copy.deepcopy(messages)
|
||||
if messages[0]['role'] == "system":
|
||||
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
|
||||
if messages[0]["role"] == "system":
|
||||
messages[1]["content"] = (
|
||||
"<system>"
|
||||
+ messages[0]["content"]
|
||||
+ "</system>\n"
|
||||
+ messages[1]["content"]
|
||||
)
|
||||
messages = messages[1:]
|
||||
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
except:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
len_message = len(ori_message)
|
||||
for msgid in range(len_message):
|
||||
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
|
||||
input_prompt = (
|
||||
input_prompt
|
||||
+ "<"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">"
|
||||
+ ori_message[msgid]["content"]
|
||||
+ "</"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">\n"
|
||||
)
|
||||
|
||||
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
|
||||
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
|
||||
input_ids = hf_tokenizer(
|
||||
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||
).to("cuda")
|
||||
output = hf_model.generate(
|
||||
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response_text, "model": model}}
|
||||
)
|
||||
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
||||
return response_text
|
||||
|
||||
|
||||
async def ollama_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def gpt_4o_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
@@ -241,7 +286,7 @@ async def bedrock_complete(
|
||||
async def hf_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await hf_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
@@ -250,10 +295,11 @@ async def hf_model_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def ollama_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await ollama_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
@@ -262,17 +308,25 @@ async def ollama_model_complete(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
|
||||
async def openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
|
||||
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
||||
# )
|
||||
async def bedrock_embedding(
|
||||
texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
|
||||
aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
|
||||
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
|
||||
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
|
||||
os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
|
||||
texts: list[str],
|
||||
model: str = "amazon.titan-embed-text-v2:0",
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
) -> np.ndarray:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
if "v2" in model:
|
||||
body = json.dumps({
|
||||
'inputText': text,
|
||||
# 'dimensions': embedding_dim,
|
||||
'embeddingTypes': ["float"]
|
||||
})
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({
|
||||
'inputText': text
|
||||
})
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise ValueError(f"Model {model} is not supported!")
|
||||
|
||||
@@ -315,29 +378,27 @@ async def bedrock_embedding(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json"
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = await response.get('body').json()
|
||||
response_body = await response.get("body").json()
|
||||
|
||||
embed_texts.append(response_body['embedding'])
|
||||
embed_texts.append(response_body["embedding"])
|
||||
elif model_provider == "cohere":
|
||||
body = json.dumps({
|
||||
'texts': texts,
|
||||
'input_type': "search_document",
|
||||
'truncate': "NONE"
|
||||
})
|
||||
body = json.dumps(
|
||||
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
||||
)
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json"
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get('body').read())
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
embed_texts = response_body['embeddings']
|
||||
embed_texts = response_body["embeddings"]
|
||||
else:
|
||||
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
||||
|
||||
@@ -345,12 +406,15 @@ async def bedrock_embedding(
|
||||
|
||||
|
||||
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
|
||||
input_ids = tokenizer(
|
||||
texts, return_tensors="pt", padding=True, truncation=True
|
||||
).input_ids
|
||||
with torch.no_grad():
|
||||
outputs = embed_model(input_ids)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
return embeddings.detach().numpy()
|
||||
|
||||
|
||||
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
||||
embed_text = []
|
||||
for text in texts:
|
||||
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
||||
|
||||
return embed_text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def main():
|
||||
result = await gpt_4o_mini_complete('How are you?')
|
||||
result = await gpt_4o_mini_complete("How are you?")
|
||||
print(result)
|
||||
|
||||
asyncio.run(main())
|
||||
|
@@ -25,6 +25,7 @@ from .base import (
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||
|
||||
|
||||
def chunking_by_token_size(
|
||||
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
||||
):
|
||||
@@ -45,6 +46,7 @@ def chunking_by_token_size(
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _handle_entity_relation_summary(
|
||||
entity_or_relation_name: str,
|
||||
description: str,
|
||||
@@ -232,6 +234,7 @@ async def _merge_edges_then_upsert(
|
||||
|
||||
return edge_data
|
||||
|
||||
|
||||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
@@ -352,7 +355,9 @@ async def extract_entities(
|
||||
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
||||
return None
|
||||
if not len(all_relationships_data):
|
||||
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
|
||||
logger.warning(
|
||||
"Didn't extract any relationships, maybe your LLM is not working"
|
||||
)
|
||||
return None
|
||||
|
||||
if entity_vdb is not None:
|
||||
@@ -370,7 +375,10 @@ async def extract_entities(
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
|
||||
"content": dp["keywords"]
|
||||
+ dp["src_id"]
|
||||
+ dp["tgt_id"]
|
||||
+ dp["description"],
|
||||
}
|
||||
for dp in all_relationships_data
|
||||
}
|
||||
@@ -378,6 +386,7 @@ async def extract_entities(
|
||||
|
||||
return knwoledge_graph_inst
|
||||
|
||||
|
||||
async def local_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -397,15 +406,20 @@ async def local_query(
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("low_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
keywords = ", ".join(keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("low_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
keywords = ", ".join(keywords)
|
||||
# Handle parsing error
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
@@ -430,11 +444,20 @@ async def local_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _build_local_query_context(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -516,6 +539,7 @@ async def _build_local_query_context(
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
async def _find_most_related_text_unit_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
||||
return all_text_units
|
||||
|
||||
|
||||
async def _find_most_related_edges_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
|
||||
)
|
||||
return all_edges_data
|
||||
|
||||
|
||||
async def global_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -628,15 +654,20 @@ async def global_query(
|
||||
try:
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("high_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
keywords = ", ".join(keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
keywords = keywords_data.get("high_level_keywords", [])
|
||||
keywords = ', '.join(keywords)
|
||||
keywords = ", ".join(keywords)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# Handle parsing error
|
||||
@@ -665,11 +696,20 @@ async def global_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _build_global_query_context(
|
||||
keywords,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -765,6 +805,7 @@ async def _build_global_query_context(
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
async def _find_most_related_entities_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
||||
|
||||
return node_datas
|
||||
|
||||
|
||||
async def _find_related_text_unit_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
):
|
||||
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
for dp in edge_datas
|
||||
@@ -822,9 +863,7 @@ async def _find_related_text_unit_from_relationships(
|
||||
all_text_units = [
|
||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||
]
|
||||
all_text_units = sorted(
|
||||
all_text_units, key=lambda x: x["order"]
|
||||
)
|
||||
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
||||
all_text_units = truncate_list_by_token_size(
|
||||
all_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
||||
|
||||
return all_text_units
|
||||
|
||||
|
||||
async def hybrid_query(
|
||||
query,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
@@ -855,18 +895,23 @@ async def hybrid_query(
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
hl_keywords = ', '.join(hl_keywords)
|
||||
ll_keywords = ', '.join(ll_keywords)
|
||||
except json.JSONDecodeError as e:
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = (
|
||||
result.replace(kw_prompt[:-1], "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.strip()
|
||||
)
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
|
||||
keywords_data = json.loads(result)
|
||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||
hl_keywords = ', '.join(hl_keywords)
|
||||
ll_keywords = ', '.join(ll_keywords)
|
||||
hl_keywords = ", ".join(hl_keywords)
|
||||
ll_keywords = ", ".join(ll_keywords)
|
||||
# Handle parsing error
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
@@ -906,52 +951,77 @@ async def hybrid_query(
|
||||
query,
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def combine_contexts(high_level_context, low_level_context):
|
||||
# Function to extract entities, relationships, and sources from context strings
|
||||
|
||||
def extract_sections(context):
|
||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||
entities_match = re.search(
|
||||
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
relationships_match = re.search(
|
||||
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
sources_match = re.search(
|
||||
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||
)
|
||||
|
||||
entities = entities_match.group(1) if entities_match else ''
|
||||
relationships = relationships_match.group(1) if relationships_match else ''
|
||||
sources = sources_match.group(1) if sources_match else ''
|
||||
entities = entities_match.group(1) if entities_match else ""
|
||||
relationships = relationships_match.group(1) if relationships_match else ""
|
||||
sources = sources_match.group(1) if sources_match else ""
|
||||
|
||||
return entities, relationships, sources
|
||||
|
||||
# Extract sections from both contexts
|
||||
|
||||
if high_level_context==None:
|
||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
||||
hl_entities, hl_relationships, hl_sources = '','',''
|
||||
if high_level_context is None:
|
||||
warnings.warn(
|
||||
"High Level context is None. Return empty High entity/relationship/source"
|
||||
)
|
||||
hl_entities, hl_relationships, hl_sources = "", "", ""
|
||||
else:
|
||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||
|
||||
|
||||
if low_level_context==None:
|
||||
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
||||
ll_entities, ll_relationships, ll_sources = '','',''
|
||||
if low_level_context is None:
|
||||
warnings.warn(
|
||||
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||
)
|
||||
ll_entities, ll_relationships, ll_sources = "", "", ""
|
||||
else:
|
||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||
|
||||
|
||||
|
||||
# Combine and deduplicate the entities
|
||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
||||
combined_entities = '\n'.join(combined_entities_set)
|
||||
combined_entities_set = set(
|
||||
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
||||
)
|
||||
combined_entities = "\n".join(combined_entities_set)
|
||||
|
||||
# Combine and deduplicate the relationships
|
||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
||||
combined_relationships = '\n'.join(combined_relationships_set)
|
||||
combined_relationships_set = set(
|
||||
filter(
|
||||
None,
|
||||
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
||||
)
|
||||
)
|
||||
combined_relationships = "\n".join(combined_relationships_set)
|
||||
|
||||
# Combine and deduplicate the sources
|
||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
||||
combined_sources = '\n'.join(combined_sources_set)
|
||||
combined_sources_set = set(
|
||||
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
||||
)
|
||||
combined_sources = "\n".join(combined_sources_set)
|
||||
|
||||
# Format the combined context
|
||||
return f"""
|
||||
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
||||
{combined_sources}
|
||||
"""
|
||||
|
||||
|
||||
async def naive_query(
|
||||
query,
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
@@ -996,8 +1067,16 @@ async def naive_query(
|
||||
system_prompt=sys_prompt,
|
||||
)
|
||||
|
||||
if len(response)>len(sys_prompt):
|
||||
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||
if len(response) > len(sys_prompt):
|
||||
response = (
|
||||
response[len(sys_prompt) :]
|
||||
.replace(sys_prompt, "")
|
||||
.replace("user", "")
|
||||
.replace("model", "")
|
||||
.replace(query, "")
|
||||
.replace("<system>", "")
|
||||
.replace("</system>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "

PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]

PROMPTS[
    "entity_extraction"
] = """-Goal-
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.

-Steps-
@@ -146,9 +144,7 @@ PROMPTS[

PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."

PROMPTS[
    "rag_response"
] = """---Role---
PROMPTS["rag_response"] = """---Role---

You are a helpful assistant responding to questions about data in the tables provided.

@@ -241,9 +237,7 @@ Output:

"""

PROMPTS[
    "naive_rag_response"
] = """You're a helpful assistant
PROMPTS["naive_rag_response"] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---
@@ -1,16 +1,11 @@
|
||||
import asyncio
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import pickle
|
||||
import hnswlib
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
import xxhash
|
||||
|
||||
from .utils import load_json, logger, write_json
|
||||
from .base import (
|
||||
@@ -19,6 +14,7 @@ from .base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
|
||||
async def drop(self):
|
||||
self._data = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = 0.2
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
async def index_done_callback(self):
|
||||
self._client.save()
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
graph = graph.copy()
|
||||
graph = cast(nx.Graph, largest_connected_component(graph))
|
||||
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)

@@ -16,18 +16,22 @@ ENCODER = None

logger = logging.getLogger("lightrag")


def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)

if not logger.handlers:
logger.addHandler(file_handler)


@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -37,6 +41,7 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else:
return None


def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None


def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()


def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()


def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""

@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):

return final_decro


def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""

@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):

return final_decro


def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)


def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)


def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens)
return content


def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]


# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data


def list_of_list_to_csv(data: list[list]):
return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
)


def save_data_to_file(data, file_name):
with open(file_name, 'w', encoding='utf-8') as f:
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
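Several helpers in the hunk above appear with their full bodies, so the reformatted code can be sanity-checked directly. A minimal sketch, assuming the functions are importable from `lightrag.utils` (the import path used by the example scripts later in this commit):

```python
# Sketch only, not part of the diff: exercises the reformatted helpers.
from lightrag.utils import (  # assumed import path
    compute_mdhash_id,
    is_float_regex,
    split_string_by_multi_markers,
)

# md5 of the content, with an optional prefix
doc_id = compute_mdhash_id("some chunk of text", prefix="chunk-")
print(doc_id.startswith("chunk-"), len(doc_id))  # True 38  (6-char prefix + 32 hex chars)

# split on several delimiters at once, trimming empty pieces
print(split_string_by_multi_markers("a<SEP>b##c", ["<SEP>", "##"]))  # ['a', 'b', 'c']

# regex-based float check
print(is_float_regex("3.14"), is_float_regex("abc"))  # True False
```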

@@ -3,11 +3,11 @@ import json
import glob
import argparse

def extract_unique_contexts(input_directory, output_directory):

def extract_unique_contexts(input_directory, output_directory):
os.makedirs(output_directory, exist_ok=True)

jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
print(f"Found {len(jsonl_files)} JSONL files.")

for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
print(f"Processing file: {filename}")

try:
with open(file_path, 'r', encoding='utf-8') as infile:
with open(file_path, "r", encoding="utf-8") as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
context = json_obj.get('context')
context = json_obj.get("context")
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
print(
f"JSON decoding error in file {filename} at line {line_number}: {e}"
)
except FileNotFoundError:
print(f"File not found: {filename}")
continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
continue

unique_contexts_list = list(unique_contexts_dict.keys())
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
print(
f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
)

try:
with open(output_path, 'w', encoding='utf-8') as outfile:
with open(output_path, "w", encoding="utf-8") as outfile:
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
print(f"Unique `context` entries have been saved to: {output_filename}")
except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
parser.add_argument(
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
)

args = parser.parse_args()
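The deduplication in this script relies on dict keys being unique and insertion-ordered, which is why contexts are stored as keys with `None` values. A standalone sketch of the same pattern on made-up records:

```python
# Same dict-as-ordered-set pattern used by extract_unique_contexts() above,
# shown on dummy records instead of JSONL files.
records = [
    {"context": "soil"},
    {"context": "water"},
    {"context": "soil"},          # duplicate, skipped
    {"title": "no context key"},  # missing key, skipped
]

unique_contexts_dict = {}
for json_obj in records:
    context = json_obj.get("context")
    if context and context not in unique_contexts_dict:
        unique_contexts_dict[context] = None

unique_contexts_list = list(unique_contexts_dict.keys())
print(unique_contexts_list)  # ['soil', 'water']
```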

@@ -4,8 +4,9 @@ import time

from lightrag import LightRAG


def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)

retries = 0
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")


cls = "agriculture"
WORKING_DIR = f"../{cls}"

@@ -7,6 +7,7 @@ from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding


## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -19,20 +20,24 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
**kwargs,
)


async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar"
base_url="https://api.upstage.ai/v1/solar",
)


## /For Upstage API


def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)

retries = 0
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")


cls = "mix"
WORKING_DIR = f"../{cls}"

if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)

rag = LightRAG(working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096,
max_token_size=8192,
func=embedding_func
)
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)

insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
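`EmbeddingFunc` as constructed here is the small dataclass formatted in the utils hunk above: it records `embedding_dim` and `max_token_size` and forwards calls to the wrapped coroutine. A minimal sketch with a dummy embedder, so no Upstage key is needed (the `lightrag.utils` import path is an assumption, as in the script above):

```python
import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc  # assumed import path


async def fake_embedding(texts: list[str]) -> np.ndarray:
    # Stand-in for openai_embedding(): one zero vector of the declared size per text.
    return np.zeros((len(texts), 4096))


embedding = EmbeddingFunc(embedding_dim=4096, max_token_size=8192, func=fake_embedding)
vectors = asyncio.run(embedding(["hello", "world"]))  # __call__ awaits the wrapped func
print(vectors.shape)  # (2, 4096)
```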

@@ -1,8 +1,8 @@
import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer


def openai_complete_if_cache(
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -19,14 +19,16 @@ def openai_complete_if_cache(
)
return response.choices[0].message.content

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def get_summary(context, tot_tokens=2000):
tokens = tokenizer.tokenize(context)
half_tokens = tot_tokens // 2

start_tokens = tokens[1000:1000 + half_tokens]
end_tokens = tokens[-(1000 + half_tokens):1000]
start_tokens = tokens[1000 : 1000 + half_tokens]
end_tokens = tokens[-(1000 + half_tokens) : 1000]

summary_tokens = start_tokens + end_tokens
summary = tokenizer.convert_tokens_to_string(summary_tokens)
@@ -34,9 +36,9 @@ def get_summary(context, tot_tokens=2000):
return summary


clses = ['agriculture']
clses = ["agriculture"]
for cls in clses:
with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
unique_contexts = json.load(f)

summaries = [get_summary(context) for context in unique_contexts]
@@ -67,7 +69,7 @@ for cls in clses:
...
"""

result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)

file_path = f"../datasets/questions/{cls}_questions.txt"
with open(file_path, "w") as file:

@@ -4,16 +4,18 @@ import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm


def extract_queries(file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = f.read()

data = data.replace('**', '')
data = data.replace("**", "")

queries = re.findall(r'- Question \d+: (.+)', data)
queries = re.findall(r"- Question \d+: (.+)", data)

return queries


async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop

def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):

def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):
loop = always_get_an_event_loop()

with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
with open(output_file, "a", encoding="utf-8") as result_file, open(
error_file, "a", encoding="utf-8"
) as err_file:
result_file.write("[\n")
first_entry = True

for query_text in tqdm(queries, desc="Processing queries", unit="query"):
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
result, error = loop.run_until_complete(
process_query(query_text, rag_instance, query_param)
)

if result:
if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

result_file.write("\n]")


if __name__ == "__main__":
cls = "agriculture"
mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
query_param = QueryParam(mode=mode)

queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
run_queries_and_save_to_json(
queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
)
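`extract_queries` above is a plain regex pass over the generated question files. A quick sketch on a made-up two-question snippet in the shape that regex expects:

```python
import re

# Made-up excerpt in the shape produced by the query-generation step.
sample = """
- Question 1: **How** does crop rotation affect soil health?
- Question 2: What role does irrigation play in yield stability?
"""

data = sample.replace("**", "")
queries = re.findall(r"- Question \d+: (.+)", data)
print(queries)
# ['How does crop rotation affect soil health?',
#  'What role does irrigation play in yield stability?']
```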

@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np


## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
**kwargs,
)


async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar"
base_url="https://api.upstage.ai/v1/solar",
)


## /For Upstage API


def extract_queries(file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = f.read()

data = data.replace('**', '')
data = data.replace("**", "")

queries = re.findall(r'- Question \d+: (.+)', data)
queries = re.findall(r"- Question \d+: (.+)", data)

return queries


async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop

def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):

def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):
loop = always_get_an_event_loop()

with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
with open(output_file, "a", encoding="utf-8") as result_file, open(
error_file, "a", encoding="utf-8"
) as err_file:
result_file.write("[\n")
first_entry = True

for query_text in tqdm(queries, desc="Processing queries", unit="query"):
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
result, error = loop.run_until_complete(
process_query(query_text, rag_instance, query_param)
)

if result:
if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

result_file.write("\n]")


if __name__ == "__main__":
cls = "mix"
mode = "hybrid"
WORKING_DIR = f"../{cls}"

rag = LightRAG(working_dir=WORKING_DIR)
rag = LightRAG(working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096,
max_token_size=8192,
func=embedding_func
)
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)
query_param = QueryParam(mode=mode)

base_dir='../datasets/questions'
base_dir = "../datasets/questions"
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
run_queries_and_save_to_json(
queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
)

@@ -1,13 +1,13 @@
aioboto3
openai
tiktoken
networkx
graspologic
nano-vectordb
hnswlib
xxhash
tenacity
transformers
torch
ollama
accelerate
aioboto3
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
tenacity
tiktoken
torch
transformers
xxhash