Optimization logic

commit af3aef5d88 (parent d6589684ef)
Author: jin
Date: 2024-11-25 13:29:55 +08:00

10 changed files with 342 additions and 423 deletions

.gitignore vendored
View File

@@ -12,3 +12,5 @@ ignore_this.txt
 .venv/
 *.ignore.*
 .ruff_cache/
+gui/
+*.log

View File

@@ -1,11 +1,16 @@
 from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi import Query
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Any
+from fastapi.responses import JSONResponse
-import sys
-import os
+import sys, os
+print(os.getcwd())
 from pathlib import Path
+script_directory = Path(__file__).resolve().parent.parent
+sys.path.append(os.path.abspath(script_directory))
 import asyncio
 import nest_asyncio
@@ -13,15 +18,11 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
-print(os.getcwd())
-script_directory = Path(__file__).resolve().parent.parent
-sys.path.append(os.path.abspath(script_directory))

 # Apply nest_asyncio to solve event loop issues
 nest_asyncio.apply()
@@ -37,18 +38,16 @@ APIKEY = "ocigenerativeai"
 # Configure working directory
 WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
 print(f"WORKING_DIR: {WORKING_DIR}")
-LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus")
+LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
 print(f"LLM_MODEL: {LLM_MODEL}")
 EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -78,10 +77,10 @@ async def get_embedding_dim():
     embedding_dim = embedding.shape[1]
     return embedding_dim

 async def init():
     # Detect embedding dimension
-    embedding_dimension = await get_embedding_dim()
+    embedding_dimension = 1024  # await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
     # Create Oracle DB connection
     # The `config` parameter is the connection configuration of Oracle DB
@@ -89,36 +88,36 @@ async def init():
     # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "",
-            "password": "",
-            "dsn": "",
-            "config_dir": "",
-            "wallet_location": "",
-            "wallet_password": "",
-            "workspace": "",
-        }  # specify which docs you want to store and query
-    )
+    oracle_db = OracleDB(config={
+        "user": "",
+        "password": "",
+        "dsn": "",
+        "config_dir": "path_to_config_dir",
+        "wallet_location": "path_to_wallet_location",
+        "wallet_password": "wallet_password",
+        "workspace": "company"
+        }  # specify which docs you want to store and query
+    )

     # Check if Oracle DB tables exist, if not, tables will be created
     await oracle_db.check_tables()

     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,
         chunk_token_size=512,
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=embedding_dimension,
             max_token_size=512,
             func=embedding_func,
         ),
         graph_storage="OracleGraphStorage",
         kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage",
+        vector_storage="OracleVectorDBStorage"
     )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
     rag.graph_storage_cls.db = oracle_db
@@ -128,6 +127,17 @@ async def init():
     return rag

+    # Extract and Insert into LightRAG storage
+    # with open("./dickens/book.txt", "r", encoding="utf-8") as f:
+    #     await rag.ainsert(f.read())
+
+    # # Perform search in different modes
+    # modes = ["naive", "local", "global", "hybrid"]
+    # for mode in modes:
+    #     print("="*20, mode, "="*20)
+    #     print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode)))
+    #     print("-"*100, "\n")

 # Data models
@@ -135,7 +145,10 @@ class QueryRequest(BaseModel):
     query: str
     mode: str = "hybrid"
     only_need_context: bool = False
+    only_need_prompt: bool = False
+
+class DataRequest(BaseModel):
+    limit: int = 100

 class InsertRequest(BaseModel):
     text: str
@@ -143,7 +156,7 @@ class InsertRequest(BaseModel):
 class Response(BaseModel):
     status: str
-    data: Optional[str] = None
+    data: Optional[Any] = None
     message: Optional[str] = None
@@ -151,7 +164,6 @@ class Response(BaseModel):
 rag = None  # defined as a global object

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
@@ -160,24 +172,39 @@ async def lifespan(app: FastAPI):
     yield

-app = FastAPI(
-    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
-)
+app = FastAPI(title="LightRAG API", description="API for RAG operations", lifespan=lifespan)

 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
-    try:
-        # loop = asyncio.get_event_loop()
-        result = await rag.aquery(
-            request.query,
-            param=QueryParam(
-                mode=request.mode, only_need_context=request.only_need_context
-            ),
-        )
-        return Response(status="success", data=result)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+    # try:
+    #     loop = asyncio.get_event_loop()
+    if request.mode == "naive":
+        top_k = 3
+    else:
+        top_k = 60
+    result = await rag.aquery(
+        request.query,
+        param=QueryParam(
+            mode=request.mode,
+            only_need_context=request.only_need_context,
+            only_need_prompt=request.only_need_prompt,
+            top_k=top_k,
+        ),
+    )
+    return Response(status="success", data=result)
+    # except Exception as e:
+    #     raise HTTPException(status_code=500, detail=str(e))

+@app.get("/data", response_model=Response)
+async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
+    if type == "nodes":
+        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
+    elif type == "edges":
+        result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
+    elif type == "statistics":
+        result = await rag.chunk_entity_relation_graph.get_statistics()
+    return Response(status="success", data=result)

 @app.post("/insert", response_model=Response)
@app.post("/insert", response_model=Response) @app.post("/insert", response_model=Response)
@@ -220,7 +247,7 @@ async def health_check():
 if __name__ == "__main__":
     import uvicorn

-    uvicorn.run(app, host="0.0.0.0", port=8020)
+    uvicorn.run(app, host="127.0.0.1", port=8020)

 # Usage example
 # To run the server, use the following command in your terminal:
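
For reference, a minimal client sketch against the reworked endpoints (not part of the commit; it assumes the server is running locally on port 8020, matching the uvicorn line above, and that the `requests` package is installed):

    import requests

    BASE = "http://127.0.0.1:8020"

    # /query now honors only_need_prompt, and the handler picks top_k by mode (3 for naive, 60 otherwise)
    resp = requests.post(f"{BASE}/query", json={"query": "What is this document about?", "mode": "hybrid"})
    print(resp.json()["data"])

    # the new /data endpoint exposes the stored graph
    for t in ("nodes", "edges", "statistics"):
        print(t, requests.get(f"{BASE}/data", params={"type": t, "limit": 10}).json())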

View File

@@ -97,6 +97,8 @@ async def main():
         graph_storage="OracleGraphStorage",
         kv_storage="OracleKVStorage",
         vector_storage="OracleVectorDBStorage",
+        addon_params={"example_number": 1, "language": "Simplfied Chinese"},
     )
     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
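
The new addon_params feed the prompt templates changed later in this commit: "language" fills the {language} slot in the entity-extraction prompt, and "example_number" truncates the example list. A sketch of what the selection logic (mirrored from extract_entities below) does with these values; PROMPTS comes from lightrag.prompt:

    from lightrag.prompt import PROMPTS

    addon_params = {"example_number": 1, "language": "Simplfied Chinese"}
    language = addon_params.get("language", PROMPTS["DEFAULT_LANGUAGE"])
    example_number = addon_params.get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(PROMPTS["entity_extraction_examples"][: int(example_number)])
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])
    print(language, len(examples))  # only the first example is kept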

View File

@@ -21,6 +21,8 @@ class QueryParam:
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
+    # Number of document chunks to retrieve.
+    # top_n: int = 10
     # Number of tokens for the original chunks.
     max_token_for_text_unit: int = 4000
     # Number of tokens for the relationship descriptions
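
As a usage sketch, the new per-mode knobs flow through QueryParam; this mirrors the API handler above and assumes an initialized LightRAG instance named rag (only_need_prompt is assumed to be an existing QueryParam field, since kg_query reads it):

    from lightrag import QueryParam

    # naive mode retrieves fewer items; the other modes keep the default of 60
    param = QueryParam(mode="naive", top_k=3, only_need_context=False, only_need_prompt=False)
    # result = await rag.aquery("What is this document about?", param=param)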

View File

@@ -333,6 +333,8 @@ class OracleGraphStorage(BaseGraphStorage):
         entity_type = node_data["entity_type"]
         description = node_data["description"]
         source_id = node_data["source_id"]
+        logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
         content = entity_name + description
         contents = [content]
         batches = [
@@ -369,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage):
         keywords = edge_data["keywords"]
         description = edge_data["description"]
         source_chunk_id = edge_data["source_id"]
+        logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
         content = keywords + source_name + target_name + description
         contents = [content]
         batches = [
@@ -544,6 +548,14 @@ class OracleGraphStorage(BaseGraphStorage):
         res = await self.db.query(sql=SQL, params=params, multirows=True)
         if res:
             return res

+    async def get_statistics(self):
+        SQL = SQL_TEMPLATES["get_statistics"]
+        params = {"workspace": self.db.workspace}
+        res = await self.db.query(sql=SQL, params=params, multirows=True)
+        if res:
+            return res

 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
     "text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -715,18 +727,36 @@ SQL_TEMPLATES = {
     WHEN NOT MATCHED THEN
     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
     values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
-    "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
-        FROM LIGHTRAG_GRAPH_NODES t1
-        LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
-        WHERE t1.workspace=:workspace
-        order by t1.CREATETIME DESC
-        fetch first :limit rows only
-        """,
+    "get_all_nodes": """WITH t0 AS (
+            SELECT name AS id, entity_type AS label, entity_type, description,
+                '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
+            FROM lightrag_graph_nodes
+            WHERE workspace = :workspace
+            ORDER BY createtime DESC fetch first :limit rows only
+        ), t1 AS (
+            SELECT t0.id, source_chunk_id
+            FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
+        ), t2 AS (
+            SELECT t1.id, LISTAGG(t2.content, '\n') content
+            FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
+            GROUP BY t1.id
+        )
+        SELECT t0.id, label, entity_type, description, t2.content
+        FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
     "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
         t1.weight,t1.DESCRIPTION,t2.content
         FROM LIGHTRAG_GRAPH_EDGES t1
         LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
         WHERE t1.workspace=:workspace
         order by t1.CREATETIME DESC
-        fetch first :limit rows only"""
+        fetch first :limit rows only""",
+    "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
+            count(distinct CASE WHEN type='edge' THEN id END) as edges_count
+        FROM (
+            select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
+                MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
+            UNION
+            select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
+                MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
+        )""",
 }
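
A quick way to exercise the new get_statistics template end to end is through the storage object, as the /data endpoint above does (a sketch; it assumes rag was initialized with the Oracle storage backends):

    import asyncio

    async def show_graph_stats(rag):
        # runs the GRAPH_TABLE-based count query against the property graph
        stats = await rag.chunk_entity_relation_graph.get_statistics()
        print(stats)  # e.g. a row with nodes_count and edges_count

    # asyncio.run(show_graph_stats(rag))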

View File

@@ -12,9 +12,8 @@ from .llm import (
 from .operate import (
     chunking_by_token_size,
     extract_entities,
-    local_query,
-    global_query,
-    hybrid_query,
+    # local_query, global_query, hybrid_query,
+    kg_query,
     naive_query,
 )
@@ -309,28 +308,8 @@ class LightRAG:
         return loop.run_until_complete(self.aquery(query, param))

     async def aquery(self, query: str, param: QueryParam = QueryParam()):
-        if param.mode == "local":
-            response = await local_query(
-                query,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-            )
-        elif param.mode == "global":
-            response = await global_query(
-                query,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-            )
-        elif param.mode == "hybrid":
-            response = await hybrid_query(
+        if param.mode in ["local", "global", "hybrid"]:
+            response = await kg_query(
                 query,
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,

View File

@@ -69,12 +69,15 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content
+    if r"\u" in content:
+        content = content.encode("utf-8").decode("unicode_escape")
+    print(content)

     if hashing_kv is not None:
         await hashing_kv.upsert(
             {args_hash: {"return": response.choices[0].message.content, "model": model}}
         )
-    return response.choices[0].message.content
+    return content
@retry( @retry(
@@ -539,7 +542,7 @@ async def openai_embedding(
     texts: list[str],
     model: str = "text-embedding-3-small",
     base_url: str = None,
-    api_key: str = None,
+    api_key: str = None
 ) -> np.ndarray:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
@@ -548,7 +551,7 @@ async def openai_embedding(
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )
     return np.array([dp.embedding for dp in response.data])
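
The new post-processing in openai_complete_if_cache un-escapes responses that come back with literal \u sequences. In isolation, the added logic behaves like this (a standalone illustration, not part of the commit):

    s = "caf\\u00e9"  # a response string containing a literal \u escape sequence
    if r"\u" in s:
        s = s.encode("utf-8").decode("unicode_escape")
    print(s)  # café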

View File

@@ -248,6 +248,13 @@ async def extract_entities(
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

     ordered_chunks = list(chunks.items())
+    # add language and example number params to prompt
+    language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
+        examples = "\n".join(PROMPTS["entity_extraction_examples"][: int(example_number)])
+    else:
+        examples = "\n".join(PROMPTS["entity_extraction_examples"])

     entity_extract_prompt = PROMPTS["entity_extraction"]
     context_base = dict(
@@ -255,7 +262,9 @@ async def extract_entities(
         record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
         completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
         entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
-    )
+        examples=examples,
+        language=language)

     continue_prompt = PROMPTS["entiti_continue_extraction"]
     if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
@@ -270,7 +279,6 @@ async def extract_entities(
         content = chunk_dp["content"]
         hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
         final_result = await use_llm_func(hint_prompt)
-
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
         for now_glean_index in range(entity_extract_max_gleaning):
             glean_result = await use_llm_func(continue_prompt, history_messages=history)
@@ -388,8 +396,7 @@ async def extract_entities(
     return knowledge_graph_inst

-async def local_query(
+async def kg_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
@@ -399,43 +406,61 @@ async def local_query(
     global_config: dict,
 ) -> str:
     context = None
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
+        examples = "\n".join(PROMPTS["keywords_extraction_examples"][: int(example_number)])
+    else:
+        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+
+    # Set mode
+    if query_param.mode not in ["local", "global", "hybrid"]:
+        logger.error(f"Unknown mode {query_param.mode} in kg_query")
+        return PROMPTS["fail_response"]
+
+    # LLM generate keywords
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
+    kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
     result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    logger.debug("local_query json_text:", json_text)
+    logger.info("kw_prompt result:")
+    print(result)
     try:
+        json_text = locate_json_string_body_from_string(result)
         keywords_data = json.loads(json_text)
-        keywords = keywords_data.get("low_level_keywords", [])
-        keywords = ", ".join(keywords)
-    except json.JSONDecodeError:
-        print(result)
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            keywords = keywords_data.get("low_level_keywords", [])
-            keywords = ", ".join(keywords)
-        # Handle parsing error
-        except json.JSONDecodeError as e:
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-    if keywords:
-        context = await _build_local_query_context(
+        hl_keywords = keywords_data.get("high_level_keywords", [])
+        ll_keywords = keywords_data.get("low_level_keywords", [])
+
+    # Handle parsing error
+    except json.JSONDecodeError as e:
+        print(f"JSON parsing error: {e} {result}")
+        return PROMPTS["fail_response"]
+
+    # Handle missing keywords
+    if hl_keywords == [] and ll_keywords == []:
+        logger.warning("low_level_keywords and high_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
+        logger.warning("low_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    else:
+        ll_keywords = ", ".join(ll_keywords)
+    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
+        logger.warning("high_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    else:
+        hl_keywords = ", ".join(hl_keywords)
+
+    # Build context
+    keywords = [ll_keywords, hl_keywords]
+    context = await _build_query_context(
         keywords,
         knowledge_graph_inst,
         entities_vdb,
+        relationships_vdb,
         text_chunks_db,
         query_param,
     )
+
     if query_param.only_need_context:
         return context
     if context is None:
@@ -443,13 +468,13 @@ async def local_query(
     sys_prompt_temp = PROMPTS["rag_response"]
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
     if query_param.only_need_prompt:
         return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
     )
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -464,22 +489,87 @@ async def local_query(
     return response

-async def _build_local_query_context(
+async def _build_query_context(
+    query: list,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+):
+    ll_kewwords, hl_keywrds = query[0], query[1]
+    if query_param.mode in ["local", "hybrid"]:
+        if ll_kewwords == "":
+            ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
+            warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
+            query_param.mode = "global"
+        else:
+            ll_entities_context, ll_relations_context, ll_text_units_context = await _get_node_data(
+                ll_kewwords,
+                knowledge_graph_inst,
+                entities_vdb,
+                text_chunks_db,
+                query_param,
+            )
+    if query_param.mode in ["global", "hybrid"]:
+        if hl_keywrds == "":
+            hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
+            warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+            query_param.mode = "local"
+        else:
+            hl_entities_context, hl_relations_context, hl_text_units_context = await _get_edge_data(
+                hl_keywrds,
+                knowledge_graph_inst,
+                relationships_vdb,
+                text_chunks_db,
+                query_param,
+            )
+    if query_param.mode == "hybrid":
+        entities_context, relations_context, text_units_context = combine_contexts(
+            [hl_entities_context, ll_entities_context],
+            [hl_relations_context, ll_relations_context],
+            [hl_text_units_context, ll_text_units_context],
+        )
+    elif query_param.mode == "local":
+        entities_context, relations_context, text_units_context = ll_entities_context, ll_relations_context, ll_text_units_context
+    elif query_param.mode == "global":
+        entities_context, relations_context, text_units_context = hl_entities_context, hl_relations_context, hl_text_units_context
+    return f"""
+# -----Entities-----
+# ```csv
+# {entities_context}
+# ```
+# -----Relationships-----
+# ```csv
+# {relations_context}
+# ```
+# -----Sources-----
+# ```csv
+# {text_units_context}
+# ```
+# """

+async def _get_node_data(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
+    # Get similar entities
     results = await entities_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return None
+    # Get entity information
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
+    # Get entity degrees
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -488,15 +578,19 @@ async def _build_local_query_context(
         for k, n, d in zip(results, node_datas, node_degrees)
         if n is not None
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
+    # Get the text chunks related to the entities
     use_text_units = await _find_most_related_text_unit_from_entities(
         node_datas, query_param, text_chunks_db, knowledge_graph_inst
     )
+    # Get the related edges
     use_relations = await _find_most_related_edges_from_entities(
         node_datas, query_param, knowledge_graph_inst
     )
     logger.info(
         f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
     )
+    # Build the prompt
     entites_section_list = [["id", "entity", "type", "description", "rank"]]
     for i, n in enumerate(node_datas):
         entites_section_list.append(
@@ -531,20 +625,7 @@ async def _build_local_query_context(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-    return f"""
------Entities-----
-```csv
-{entities_context}
-```
------Relationships-----
-```csv
-{relations_context}
-```
------Sources-----
-```csv
-{text_units_context}
-```
-"""
+    return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities( async def _find_most_related_text_unit_from_entities(
@@ -659,88 +740,9 @@ async def _find_most_related_edges_from_entities(
     return all_edges_data

-async def global_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-) -> str:
-    context = None
-    use_model_func = global_config["llm_model_func"]
-
-    kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
-    result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    logger.debug("global json_text:", json_text)
-    try:
-        keywords_data = json.loads(json_text)
-        keywords = keywords_data.get("high_level_keywords", [])
-        keywords = ", ".join(keywords)
-    except json.JSONDecodeError:
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            keywords = keywords_data.get("high_level_keywords", [])
-            keywords = ", ".join(keywords)
-        except json.JSONDecodeError as e:
-            # Handle parsing error
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-    if keywords:
-        context = await _build_global_query_context(
-            keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
-    if query_param.only_need_context:
-        return context
-    if context is None:
-        return PROMPTS["fail_response"]
-    sys_prompt_temp = PROMPTS["rag_response"]
-    sys_prompt = sys_prompt_temp.format(
-        context_data=context, response_type=query_param.response_type
-    )
-    if query_param.only_need_prompt:
-        return sys_prompt
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-    )
-    if len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-    return response
-
-async def _build_global_query_context(
+async def _get_edge_data(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
@@ -782,6 +784,7 @@ async def _build_global_query_context(
     logger.info(
         f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
     )
+
     relations_section_list = [
         ["id", "source", "target", "description", "keywords", "weight", "rank"]
     ]
@@ -816,21 +819,8 @@ async def _build_global_query_context(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-
-    return f"""
------Entities-----
-```csv
-{entities_context}
-```
------Relationships-----
-```csv
-{relations_context}
-```
------Sources-----
-```csv
-{text_units_context}
-```
-"""
+    return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships( async def _find_most_related_entities_from_relationships(
@@ -901,137 +891,11 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units

-async def hybrid_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-) -> str:
-    low_level_context = None
-    high_level_context = None
-    use_model_func = global_config["llm_model_func"]
-
-    kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
-
-    result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    logger.debug("hybrid_query json_text:", json_text)
-    try:
-        keywords_data = json.loads(json_text)
-        hl_keywords = keywords_data.get("high_level_keywords", [])
-        ll_keywords = keywords_data.get("low_level_keywords", [])
-        hl_keywords = ", ".join(hl_keywords)
-        ll_keywords = ", ".join(ll_keywords)
-    except json.JSONDecodeError:
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            hl_keywords = keywords_data.get("high_level_keywords", [])
-            ll_keywords = keywords_data.get("low_level_keywords", [])
-            hl_keywords = ", ".join(hl_keywords)
-            ll_keywords = ", ".join(ll_keywords)
-        # Handle parsing error
-        except json.JSONDecodeError as e:
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-
-    if ll_keywords:
-        low_level_context = await _build_local_query_context(
-            ll_keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            text_chunks_db,
-            query_param,
-        )
-
-    if hl_keywords:
-        high_level_context = await _build_global_query_context(
-            hl_keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
-
-    context = combine_contexts(high_level_context, low_level_context)
-
-    if query_param.only_need_context:
-        return context
-    if context is None:
-        return PROMPTS["fail_response"]
-
-    sys_prompt_temp = PROMPTS["rag_response"]
-    sys_prompt = sys_prompt_temp.format(
-        context_data=context, response_type=query_param.response_type
-    )
-    if query_param.only_need_prompt:
-        return sys_prompt
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-    )
-    if len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-    return response
-
-def combine_contexts(high_level_context, low_level_context):
+def combine_contexts(entities, relationships, sources):
     # Function to extract entities, relationships, and sources from context strings
-    def extract_sections(context):
-        entities_match = re.search(
-            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-        relationships_match = re.search(
-            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-        sources_match = re.search(
-            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-        entities = entities_match.group(1) if entities_match else ""
-        relationships = relationships_match.group(1) if relationships_match else ""
-        sources = sources_match.group(1) if sources_match else ""
-        return entities, relationships, sources
-
-    # Extract sections from both contexts
-    if high_level_context is None:
-        warnings.warn(
-            "High Level context is None. Return empty High entity/relationship/source"
-        )
-        hl_entities, hl_relationships, hl_sources = "", "", ""
-    else:
-        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
-
-    if low_level_context is None:
-        warnings.warn(
-            "Low Level context is None. Return empty Low entity/relationship/source"
-        )
-        ll_entities, ll_relationships, ll_sources = "", "", ""
-    else:
-        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+    hl_entities, ll_entities = entities[0], entities[1]
+    hl_relationships, ll_relationships = relationships[0], relationships[1]
+    hl_sources, ll_sources = sources[0], sources[1]

     # Combine and deduplicate the entities
     combined_entities = process_combine_contexts(hl_entities, ll_entities)
@@ -1043,21 +907,7 @@ def combine_contexts(high_level_context, low_level_context):
     # Combine and deduplicate the sources
     combined_sources = process_combine_contexts(hl_sources, ll_sources)

-    # Format the combined context
-    return f"""
------Entities-----
-```csv
-{combined_entities}
-```
------Relationships-----
-```csv
-{combined_relationships}
-```
------Sources-----
-```csv
-{combined_sources}
-```
-"""
+    return combined_entities, combined_relationships, combined_sources
async def naive_query( async def naive_query(
@@ -1080,7 +930,7 @@ async def naive_query(
         max_token_size=query_param.max_token_for_text_unit,
     )
     logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
-    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
+    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
     if query_param.only_need_context:
         return section
     sys_prompt_temp = PROMPTS["naive_rag_response"]

View File

@@ -2,6 +2,7 @@ GRAPH_FIELD_SEP = "<SEP>"
 PROMPTS = {}

+PROMPTS["DEFAULT_LANGUAGE"] = "English"
 PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
 PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
 PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
@@ -11,6 +12,7 @@ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
 PROMPTS["entity_extraction"] = """-Goal-
 Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+Use {language} as output language.

 -Steps-
 1. Identify all entities. For each identified entity, extract the following information:
@@ -38,7 +40,19 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
 ######################
 -Examples-
 ######################
-Example 1:
+{examples}
+
+#############################
+-Real Data-
+######################
+Entity_types: {entity_types}
+Text: {input_text}
+######################
+Output:
+"""
+
+PROMPTS["entity_extraction_examples"] = [
+    """Example 1:

 Entity_types: [person, technology, mission, organization, location]
 Text:
@@ -62,8 +76,8 @@ Output:
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
############################# #############################""",
Example 2: """Example 2:
Entity_types: [person, technology, mission, organization, location] Entity_types: [person, technology, mission, organization, location]
Text: Text:
@@ -80,8 +94,8 @@ Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter} ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter} ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter} ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
############################# #############################""",
Example 3: """Example 3:
Entity_types: [person, role, technology, organization, event, location, concept] Entity_types: [person, role, technology, organization, event, location, concept]
Text: Text:
@@ -107,22 +121,15 @@ Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter} ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
############################# #############################"""
-Real Data- ]
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
PROMPTS[ PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context. Make sure it is written in third person, and include the entity names so we the have full context.
Use Chinese as output language.
####### #######
-Data- -Data-
@@ -132,14 +139,10 @@ Description List: {description_list}
 Output:
 """

-PROMPTS[
-    "entiti_continue_extraction"
-] = """MANY entities were missed in the last extraction. Add them below using the same format:
+PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format:
 """

-PROMPTS[
-    "entiti_if_loop_extraction"
-] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
+PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
 """

 PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
@@ -169,6 +172,7 @@ Add sections and commentary to the response as appropriate for the length and fo
 PROMPTS["keywords_extraction"] = """---Role---

 You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
+Use Chinese as output language.

 ---Goal---
@@ -184,7 +188,20 @@ Given the query, list both high-level and low-level keywords. High-level keyword
 ######################
 -Examples-
 ######################
-Example 1:
+{examples}
+
+#############################
+-Real Data-
+######################
+Query: {query}
+######################
+The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
+Output:
+"""
+
+PROMPTS["keywords_extraction_examples"] = [
+    """Example 1:

 Query: "How does international trade influence global economic stability?"
 ################
@@ -193,8 +210,8 @@ Output:
 "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
 "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
 }}
-#############################
-Example 2:
+#############################""",
+    """Example 2:

 Query: "What are the environmental consequences of deforestation on biodiversity?"
 ################
@@ -203,8 +220,8 @@ Output:
 "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
 "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
 }}
-#############################
-Example 3:
+#############################""",
+    """Example 3:

 Query: "What is the role of education in reducing poverty?"
 ################
@@ -213,14 +230,9 @@ Output:
 "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
 "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
 }}
-#############################
--Real Data-
-######################
-Query: {query}
-######################
-Output:
-"""
+#############################"""
+]

 PROMPTS["naive_rag_response"] = """---Role---

View File

@@ -47,14 +47,26 @@ class EmbeddingFunc:
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
-    maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
-    if maybe_json_str is not None:
-        maybe_json_str = maybe_json_str.group(0)
-        maybe_json_str = maybe_json_str.replace("\\n", "")
-        maybe_json_str = maybe_json_str.replace("\n", "")
-        maybe_json_str = maybe_json_str.replace("'", '"')
-        return maybe_json_str
-    else:
-        return None
+    try:
+        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
+        if maybe_json_str is not None:
+            maybe_json_str = maybe_json_str.group(0)
+            maybe_json_str = maybe_json_str.replace("\\n", "")
+            maybe_json_str = maybe_json_str.replace("\n", "")
+            maybe_json_str = maybe_json_str.replace("'", '"')
+            json.loads(maybe_json_str)
+            return maybe_json_str
+    except:
+        # try:
+        #     content = (
+        #         content.replace(kw_prompt[:-1], "")
+        #         .replace("user", "")
+        #         .replace("model", "")
+        #         .strip()
+        #     )
+        #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
+        #     json.loads(maybe_json_str)
+        return None
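
The helper now round-trips the candidate through json.loads before returning it, so a string that merely looks like JSON comes back as None rather than an unvalidated fragment. A small illustration (a sketch; the function is imported from lightrag.utils):

    from lightrag.utils import locate_json_string_body_from_string

    good = 'noise {"high_level_keywords": ["trade"], "low_level_keywords": ["tariffs"]} noise'
    print(locate_json_string_body_from_string(good))  # the JSON body, normalized to double quotes
    print(locate_json_string_body_from_string("no json here"))  # None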