Optimization logic

jin
2024-11-25 13:29:55 +08:00
parent 662303f605
commit 89c2de54a2
10 changed files with 342 additions and 423 deletions

.gitignore
View File

@@ -12,3 +12,5 @@ ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log

View File

@@ -1,11 +1,16 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional
from typing import Optional,Any
from fastapi.responses import JSONResponse
import sys
import os
import sys, os
print(os.getcwd())
from pathlib import Path
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
import asyncio
import nest_asyncio
@@ -13,15 +18,11 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
@@ -37,18 +38,16 @@ APIKEY = "ocigenerativeai"
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus")
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -78,10 +77,10 @@ async def get_embedding_dim():
embedding_dim = embedding.shape[1]
return embedding_dim
async def init():
# Detect embedding dimension
embedding_dimension = await get_embedding_dim()
embedding_dimension = 1024 #await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
@@ -89,36 +88,36 @@ async def init():
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "",
"wallet_location": "",
"wallet_password": "",
"workspace": "",
} # specify which docs you want to store and query
)
oracle_db = OracleDB(config={
"user":"",
"password":"",
"dsn":"",
"config_dir":"path_to_config_dir",
"wallet_location":"path_to_wallet_location",
"wallet_password":"wallet_password",
"workspace":"company"
} # specify which docs you want to store and query
)
# Check if Oracle DB tables exist, if not, tables will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
)
# Set the KV/vector/graph storage's `db` property so all operations use the same connection pool
rag.graph_storage_cls.db = oracle_db
@@ -128,6 +127,17 @@ async def init():
return rag
# Extract and Insert into LightRAG storage
#with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# await rag.ainsert(f.read())
# # Perform search in different modes
# modes = ["naive", "local", "global", "hybrid"]
# for mode in modes:
# print("="*20, mode, "="*20)
# print(await rag.aquery("What is this document about?", param=QueryParam(mode=mode)))
# print("-"*100, "\n")
# Data models
@@ -135,7 +145,10 @@ class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
only_need_prompt: bool = False
class DataRequest(BaseModel):
limit: int = 100
class InsertRequest(BaseModel):
text: str
@@ -143,7 +156,7 @@ class InsertRequest(BaseModel):
class Response(BaseModel):
status: str
data: Optional[str] = None
data: Optional[Any] = None
message: Optional[str] = None
@@ -151,7 +164,6 @@ class Response(BaseModel):
rag = None  # defined as a global object
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
@@ -160,24 +172,39 @@ async def lifespan(app: FastAPI):
yield
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
#try:
# loop = asyncio.get_event_loop()
result = await rag.aquery(
if request.mode == "naive":
top_k = 3
else:
top_k = 60
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode, only_need_context=request.only_need_context
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k
),
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return Response(status="success", data=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
if type == "nodes":
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
elif type == "edges":
result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
elif type == "statistics":
result = await rag.chunk_entity_relation_graph.get_statistics()
return Response(status="success", data=result)
@app.post("/insert", response_model=Response)
@@ -220,7 +247,7 @@ async def health_check():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
uvicorn.run(app, host="127.0.0.1", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
@@ -237,4 +264,4 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"
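# The fields added in this commit can be exercised the same way; a hedged sketch
# (request fields mirror QueryRequest, query-string parameters mirror query_all_nodes above):
# 5. Query returning only the assembled prompt (new only_need_prompt field):
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "What is this document about?", "mode": "hybrid", "only_need_prompt": true}'
# 6. Graph data via the new /data endpoint (type is one of nodes, edges, statistics):
# curl -X GET "http://127.0.0.1:8020/data?type=statistics"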

View File

@@ -97,6 +97,8 @@ async def main():
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params = {"example_number":1, "language":"Simplified Chinese"},
)
# Set the KV/vector/graph storage's `db` property so all operations use the same connection pool

View File

@@ -21,6 +21,8 @@ class QueryParam:
response_type: str = "Multiple Paragraphs"
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
# Number of document chunks to retrieve.
# top_n: int = 10
# Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
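For illustration, a minimal sketch of this knob in use; the per-mode choice mirrors the API demo earlier in this commit, and any value other than the default top_k=60 is illustrative:
from lightrag import QueryParam
# naive mode retrieves raw chunks, so the API demo above passes a small top_k
naive_param = QueryParam(mode="naive", top_k=3)
# local/global/hybrid retrieve entities and relationships; the default top_k=60 applies
kg_param = QueryParam(mode="hybrid", top_k=60)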

View File

@@ -333,6 +333,8 @@ class OracleGraphStorage(BaseGraphStorage):
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
content = entity_name + description
contents = [content]
batches = [
@@ -369,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
content = keywords + source_name + target_name + description
contents = [content]
batches = [
@@ -544,6 +548,14 @@ class OracleGraphStorage(BaseGraphStorage):
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
async def get_statistics(self):
SQL = SQL_TEMPLATES["get_statistics"]
params = {"workspace":self.db.workspace}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
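A hedged usage sketch of the new method, called the way the /data endpoint does (assuming an initialized LightRAG instance named rag; result keys come from the get_statistics SQL template below, and exact key casing depends on the Oracle driver):
stats = await rag.chunk_entity_relation_graph.get_statistics()
# e.g. [{"NODES_COUNT": 120, "EDGES_COUNT": 340}]
print(stats)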
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -715,18 +727,36 @@ SQL_TEMPLATES = {
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
"get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_NODES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only
""",
"get_all_nodes":"""WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
FROM lightrag_graph_nodes
WHERE workspace = :workspace
ORDER BY createtime DESC fetch first :limit rows only
), t1 AS (
SELECT t0.id, source_chunk_id
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
), t2 AS (
SELECT t1.id, LISTAGG(t2.content, '\n') content
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
GROUP BY t1.id
)
SELECT t0.id, label, entity_type, description, t2.content
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
t1.weight,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_EDGES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only"""
fetch first :limit rows only""",
"get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
FROM (
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
UNION
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
)""",
}

View File

@@ -12,9 +12,8 @@ from .llm import (
from .operate import (
chunking_by_token_size,
extract_entities,
local_query,
global_query,
hybrid_query,
# local_query,global_query,hybrid_query,
kg_query,
naive_query,
)
@@ -309,28 +308,8 @@ class LightRAG:
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
if param.mode == "local":
response = await local_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "global":
response = await global_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hybrid":
response = await hybrid_query(
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
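Since local, global, and hybrid now all route through kg_query, a minimal sketch of exercising every mode (assuming an initialized rag instance, as in the commented example earlier in this commit):
for mode in ["naive", "local", "global", "hybrid"]:
    print(await rag.aquery("What is this document about?", param=QueryParam(mode=mode)))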

View File

@@ -69,12 +69,15 @@ async def openai_complete_if_cache(
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
if r'\u' in content:
content = content.encode('utf-8').decode('unicode_escape')
print(content)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
return response.choices[0].message.content
return content
@retry(
@@ -539,7 +542,7 @@ async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_key: str = None
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@@ -548,7 +551,7 @@ async def openai_embedding(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])

View File

@@ -248,14 +248,23 @@ async def extract_entities(
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
else:
examples="\n".join(PROMPTS["entity_extraction_examples"])
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
)
examples=examples,
language=language)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
@@ -270,7 +279,6 @@ async def extract_entities(
content = chunk_dp["content"]
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
@@ -388,8 +396,7 @@ async def extract_entities(
return knowledge_graph_inst
async def local_query(
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
@@ -399,43 +406,61 @@ async def local_query(
global_config: dict,
) -> str:
context = None
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
else:
examples="\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# Use the LLM to generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("local_query json_text:", json_text)
kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
result = await use_model_func(kw_prompt)
logger.info(f"kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError:
print(result)
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
keywords = ", ".join(keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if keywords:
context = await _build_local_query_context(
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
# Handle missing keywords
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords are empty")
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
logger.warning("low_level_keywords is empty")
return PROMPTS["fail_response"]
else:
ll_keywords = ", ".join(ll_keywords)
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
logger.warning("high_level_keywords is empty")
return PROMPTS["fail_response"]
else:
hl_keywords = ", ".join(hl_keywords)
# Build context
keywords = [ll_keywords, hl_keywords]
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
@@ -443,13 +468,13 @@ async def local_query(
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -464,22 +489,87 @@ async def local_query(
return response
async def _build_local_query_context(
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
ll_keywords, hl_keywords = query[0], query[1]
if query_param.mode in ["local", "hybrid"]:
if ll_keywords == "":
ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
warnings.warn("Low-level context is empty. Returning empty low-level entities/relationships/sources")
query_param.mode = "global"
else:
ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param
)
if query_param.mode in ["global", "hybrid"]:
if hl_keywords == "":
hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
warnings.warn("High-level context is empty. Returning empty high-level entities/relationships/sources")
query_param.mode = "local"
else:
hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param
)
if query_param.mode == 'hybrid':
entities_context,relations_context,text_units_context = combine_contexts(
[hl_entities_context,ll_entities_context],
[hl_relations_context,ll_relations_context],
[hl_text_units_context,ll_text_units_context]
)
elif query_param.mode == 'local':
entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
elif query_param.mode == 'global':
entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
return f"""
# -----Entities-----
# ```csv
# {entities_context}
# ```
# -----Relationships-----
# ```csv
# {relations_context}
# ```
# -----Sources-----
# ```csv
# {text_units_context}
# ```
# """
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
# Retrieve similar entities
results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return None
# Retrieve entity information
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
# Retrieve entity degrees
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
@@ -488,15 +578,19 @@ async def _build_local_query_context(
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
# Retrieve text chunks related to the entities
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
# Retrieve the related edges
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
)
)
# Build the prompt context sections
entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(node_datas):
entites_section_list.append(
@@ -531,20 +625,7 @@ async def _build_local_query_context(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
return entities_context,relations_context,text_units_context
async def _find_most_related_text_unit_from_entities(
@@ -659,88 +740,9 @@ async def _find_most_related_edges_from_entities(
return all_edges_data
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
context = None
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("global json_text:", json_text)
try:
keywords_data = json.loads(json_text)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError:
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ", ".join(keywords)
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if keywords:
context = await _build_global_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
async def _build_global_query_context(
async def _get_edge_data(
keywords,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
@@ -782,6 +784,7 @@ async def _build_global_query_context(
logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
)
relations_section_list = [
["id", "source", "target", "description", "keywords", "weight", "rank"]
]
@@ -816,21 +819,8 @@ async def _build_global_query_context(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def _find_most_related_entities_from_relationships(
@@ -901,137 +891,11 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
async def hybrid_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
low_level_context = None
high_level_context = None
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
json_text = locate_json_string_body_from_string(result)
logger.debug("hybrid_query json_text:", json_text)
try:
keywords_data = json.loads(json_text)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
except json.JSONDecodeError:
try:
result = (
result.replace(kw_prompt[:-1], "")
.replace("user", "")
.replace("model", "")
.strip()
)
result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
hl_keywords = ", ".join(hl_keywords)
ll_keywords = ", ".join(ll_keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
return PROMPTS["fail_response"]
if ll_keywords:
low_level_context = await _build_local_query_context(
ll_keywords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param,
)
if hl_keywords:
high_level_context = await _build_global_query_context(
hl_keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
def combine_contexts(high_level_context, low_level_context):
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
def extract_sections(context):
entities_match = re.search(
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
relationships_match = re.search(
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
sources_match = re.search(
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
)
entities = entities_match.group(1) if entities_match else ""
relationships = relationships_match.group(1) if relationships_match else ""
sources = sources_match.group(1) if sources_match else ""
return entities, relationships, sources
# Extract sections from both contexts
if high_level_context is None:
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
hl_entities, hl_relationships, hl_sources = "", "", ""
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
if low_level_context is None:
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
ll_entities, ll_relationships, ll_sources = "", "", ""
else:
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
hl_entities, ll_entities = entities[0], entities[1]
hl_relationships, ll_relationships = relationships[0],relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
@@ -1043,21 +907,7 @@ def combine_contexts(high_level_context, low_level_context):
# Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources)
# Format the combined context
return f"""
-----Entities-----
```csv
{combined_entities}
```
-----Relationships-----
```csv
{combined_relationships}
```
-----Sources-----
```csv
{combined_sources}
```
"""
return combined_entities, combined_relationships, combined_sources
async def naive_query(
@@ -1080,7 +930,7 @@ async def naive_query(
max_token_size=query_param.max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]

View File

@@ -2,6 +2,7 @@ GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
@@ -11,6 +12,7 @@ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
@@ -38,7 +40,19 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
######################
-Examples-
######################
Example 1:
{examples}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -62,8 +76,8 @@ Output:
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
#############################
Example 2:
#############################""",
"""Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -80,8 +94,8 @@ Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
#############################
Example 3:
#############################""",
"""Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
@@ -107,22 +121,15 @@ Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
#############################"""
]
PROMPTS[
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.
Use Chinese as output language.
#######
-Data-
@@ -132,14 +139,10 @@ Description List: {description_list}
Output:
"""
PROMPTS[
"entiti_continue_extraction"
] = """MANY entities were missed in the last extraction. Add them below using the same format:
PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format:
"""
PROMPTS[
"entiti_if_loop_extraction"
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
"""
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
@@ -169,6 +172,7 @@ Add sections and commentary to the response as appropriate for the length and fo
PROMPTS["keywords_extraction"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
Use Chinese as output language.
---Goal---
@@ -184,7 +188,20 @@ Given the query, list both high-level and low-level keywords. High-level keyword
######################
-Examples-
######################
Example 1:
{examples}
#############################
-Real Data-
######################
Query: {query}
######################
The `Output` should be human-readable text, not unicode escape sequences. Keep the same language as `Query`.
Output:
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
Query: "How does international trade influence global economic stability?"
################
@@ -193,8 +210,8 @@ Output:
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}}
#############################
Example 2:
#############################""",
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
################
@@ -203,8 +220,8 @@ Output:
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}}
#############################
Example 3:
#############################""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
################
@@ -213,14 +230,9 @@ Output:
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}}
#############################
-Real Data-
######################
Query: {query}
######################
Output:
#############################"""
]
"""
PROMPTS["naive_rag_response"] = """---Role---

View File

@@ -47,14 +47,26 @@ class EmbeddingFunc:
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
maybe_json_str = maybe_json_str.group(0)
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
return maybe_json_str
else:
try:
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
maybe_json_str = maybe_json_str.group(0)
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
json.loads(maybe_json_str)
return maybe_json_str
except:
# try:
# content = (
# content.replace(kw_prompt[:-1], "")
# .replace("user", "")
# .replace("model", "")
# .strip()
# )
# maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
# json.loads(maybe_json_str)
return None
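A quick illustrative check of the hardened helper (inputs are made up; json and re are already imported in lightrag.utils):
assert locate_json_string_body_from_string('noise {"high_level_keywords": ["Trade"]} noise') == '{"high_level_keywords": ["Trade"]}'
assert locate_json_string_body_from_string("no json here") is None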