Merge pull request #325 from jin38324/main

Enhance Query Logic and Add Configurable Features
Committed by zrguo via GitHub on 2024-11-27 18:44:01 +08:00
12 changed files with 396 additions and 390 deletions

.gitignore

@@ -12,3 +12,5 @@ ignore_this.txt
 .venv/
 *.ignore.*
 .ruff_cache/
+gui/
+*.log

README.md

@@ -555,6 +555,35 @@ if __name__ == "__main__":
 </details>
+### LightRAG init parameters
+
+| **Parameter** | **Type** | **Explanation** | **Default** |
+| --- | --- | --- | --- |
+| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
+| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
+| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
+| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
+| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
+| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
+| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
+| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
+| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
+| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
+| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
+| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
+| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
+| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
+| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
+| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
+| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
+| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
+| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
+| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
+| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
+| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
+| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
+| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
 ## API Server Implementation

 LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
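For orientation, here is a minimal sketch (not part of the diff) of how the documented parameters map onto the constructor; the import paths are the ones used elsewhere in this PR, and the `working_dir` and `addon_params` values are illustrative:

```python
# Minimal sketch: wiring the documented init parameters together.
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

rag = LightRAG(
    working_dir="./lightrag_cache",          # illustrative path
    kv_storage="JsonKVStorage",              # or "OracleKVStorage"
    vector_storage="NanoVectorDBStorage",    # or "OracleVectorDBStorage"
    graph_storage="NetworkXStorage",         # or "Neo4JStorage" / "OracleGraphStorage"
    chunk_token_size=1200,
    chunk_overlap_token_size=100,
    llm_model_func=gpt_4o_mini_complete,
    enable_llm_cache=True,
    addon_params={"example_number": 1, "language": "Simplified Chinese"},
)
rag.insert("Some document text.")
print(rag.query("What is this text about?", param=QueryParam(mode="hybrid")))
```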

examples/insert_custom_kg.py

@@ -1,5 +1,5 @@
 import os
-from lightrag import LightRAG, QueryParam
+from lightrag import LightRAG
 from lightrag.llm import gpt_4o_mini_complete
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -24,50 +24,50 @@ custom_kg = {
"entity_name": "CompanyA", "entity_name": "CompanyA",
"entity_type": "Organization", "entity_type": "Organization",
"description": "A major technology company", "description": "A major technology company",
"source_id": "Source1" "source_id": "Source1",
}, },
{ {
"entity_name": "ProductX", "entity_name": "ProductX",
"entity_type": "Product", "entity_type": "Product",
"description": "A popular product developed by CompanyA", "description": "A popular product developed by CompanyA",
"source_id": "Source1" "source_id": "Source1",
}, },
{ {
"entity_name": "PersonA", "entity_name": "PersonA",
"entity_type": "Person", "entity_type": "Person",
"description": "A renowned researcher in AI", "description": "A renowned researcher in AI",
"source_id": "Source2" "source_id": "Source2",
}, },
{ {
"entity_name": "UniversityB", "entity_name": "UniversityB",
"entity_type": "Organization", "entity_type": "Organization",
"description": "A leading university specializing in technology and sciences", "description": "A leading university specializing in technology and sciences",
"source_id": "Source2" "source_id": "Source2",
}, },
{ {
"entity_name": "CityC", "entity_name": "CityC",
"entity_type": "Location", "entity_type": "Location",
"description": "A large metropolitan city known for its culture and economy", "description": "A large metropolitan city known for its culture and economy",
"source_id": "Source3" "source_id": "Source3",
}, },
{ {
"entity_name": "EventY", "entity_name": "EventY",
"entity_type": "Event", "entity_type": "Event",
"description": "An annual technology conference held in CityC", "description": "An annual technology conference held in CityC",
"source_id": "Source3" "source_id": "Source3",
}, },
{ {
"entity_name": "CompanyD", "entity_name": "CompanyD",
"entity_type": "Organization", "entity_type": "Organization",
"description": "A financial services company specializing in insurance", "description": "A financial services company specializing in insurance",
"source_id": "Source4" "source_id": "Source4",
}, },
{ {
"entity_name": "ServiceZ", "entity_name": "ServiceZ",
"entity_type": "Service", "entity_type": "Service",
"description": "An insurance product offered by CompanyD", "description": "An insurance product offered by CompanyD",
"source_id": "Source4" "source_id": "Source4",
} },
], ],
"relationships": [ "relationships": [
{ {
@@ -76,7 +76,7 @@ custom_kg = {
"description": "CompanyA develops ProductX", "description": "CompanyA develops ProductX",
"keywords": "develop, produce", "keywords": "develop, produce",
"weight": 1.0, "weight": 1.0,
"source_id": "Source1" "source_id": "Source1",
}, },
{ {
"src_id": "PersonA", "src_id": "PersonA",
@@ -84,7 +84,7 @@ custom_kg = {
"description": "PersonA works at UniversityB", "description": "PersonA works at UniversityB",
"keywords": "employment, affiliation", "keywords": "employment, affiliation",
"weight": 0.9, "weight": 0.9,
"source_id": "Source2" "source_id": "Source2",
}, },
{ {
"src_id": "CityC", "src_id": "CityC",
@@ -92,7 +92,7 @@ custom_kg = {
"description": "EventY is hosted in CityC", "description": "EventY is hosted in CityC",
"keywords": "host, location", "keywords": "host, location",
"weight": 0.8, "weight": 0.8,
"source_id": "Source3" "source_id": "Source3",
}, },
{ {
"src_id": "CompanyD", "src_id": "CompanyD",
@@ -100,9 +100,9 @@ custom_kg = {
"description": "CompanyD provides ServiceZ", "description": "CompanyD provides ServiceZ",
"keywords": "provide, offer", "keywords": "provide, offer",
"weight": 1.0, "weight": 1.0,
"source_id": "Source4" "source_id": "Source4",
} },
] ],
} }
rag.insert_custom_kg(custom_kg) rag.insert_custom_kg(custom_kg)
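Once inserted, the hand-built KG is queryable like any ingested corpus — a short sketch (not part of the diff):

```python
# Sketch: querying the custom KG inserted above. QueryParam is re-imported
# here because this commit drops it from the example's import line.
from lightrag import QueryParam

print(rag.query("What does CompanyA develop?", param=QueryParam(mode="local")))
```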

examples/lightrag_api_oracle_demo.py

@@ -1,10 +1,13 @@
 from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi import Query
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Any
 import sys
 import os
 from pathlib import Path
 import asyncio
@@ -16,9 +19,7 @@ import numpy as np
 from lightrag.kg.oracle_impl import OracleDB

 print(os.getcwd())
 script_directory = Path(__file__).resolve().parent.parent
 sys.path.append(os.path.abspath(script_directory))
@@ -37,14 +38,13 @@ APIKEY = "ocigenerativeai"
 # Configure working directory
 WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
 print(f"WORKING_DIR: {WORKING_DIR}")
-LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus")
+LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
 print(f"LLM_MODEL: {LLM_MODEL}")
 EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
 print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
 EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
 print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -94,10 +94,10 @@ async def init():
"user": "", "user": "",
"password": "", "password": "",
"dsn": "", "dsn": "",
"config_dir": "", "config_dir": "path_to_config_dir",
"wallet_location": "", "wallet_location": "path_to_wallet_location",
"wallet_password": "", "wallet_password": "wallet_password",
"workspace": "", "workspace": "company",
} # specify which docs you want to store and query } # specify which docs you want to store and query
) )
@@ -105,6 +105,7 @@ async def init():
     await oracle_db.check_tables()
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
+    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,
@@ -128,6 +129,17 @@ async def init():
     return rag

+# Extract and Insert into LightRAG storage
+# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
+#     await rag.ainsert(f.read())
+
+# # Perform search in different modes
+# modes = ["naive", "local", "global", "hybrid"]
+# for mode in modes:
+#     print("="*20, mode, "="*20)
+#     print(await rag.aquery("What is this document about?", param=QueryParam(mode=mode)))
+#     print("-"*100, "\n")

 # Data models
@@ -135,6 +147,11 @@ class QueryRequest(BaseModel):
     query: str
     mode: str = "hybrid"
     only_need_context: bool = False
+    only_need_prompt: bool = False
+
+
+class DataRequest(BaseModel):
+    limit: int = 100


 class InsertRequest(BaseModel):
@@ -143,7 +160,7 @@ class InsertRequest(BaseModel):
 class Response(BaseModel):
     status: str
-    data: Optional[str] = None
+    data: Optional[Any] = None
     message: Optional[str] = None
@@ -167,17 +184,35 @@ app = FastAPI(
 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
-    try:
-        # loop = asyncio.get_event_loop()
+    # try:
+    #     loop = asyncio.get_event_loop()
+    if request.mode == "naive":
+        top_k = 3
+    else:
+        top_k = 60
     result = await rag.aquery(
         request.query,
         param=QueryParam(
-            mode=request.mode, only_need_context=request.only_need_context
+            mode=request.mode,
+            only_need_context=request.only_need_context,
+            only_need_prompt=request.only_need_prompt,
+            top_k=top_k,
         ),
     )
     return Response(status="success", data=result)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+    # except Exception as e:
+    #     raise HTTPException(status_code=500, detail=str(e))


+@app.get("/data", response_model=Response)
+async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
+    if type == "nodes":
+        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
+    elif type == "edges":
+        result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
+    elif type == "statistics":
+        result = await rag.chunk_entity_relation_graph.get_statistics()
+    return Response(status="success", data=result)


 @app.post("/insert", response_model=Response)
@@ -220,7 +255,7 @@ async def health_check():
 if __name__ == "__main__":
     import uvicorn

-    uvicorn.run(app, host="0.0.0.0", port=8020)
+    uvicorn.run(app, host="127.0.0.1", port=8020)

 # Usage example
 # To run the server, use the following command in your terminal:
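For reference, a hedged sketch (not part of the diff) of exercising the reworked `/query` and new `/data` endpoints, assuming the server above is running on 127.0.0.1:8020:

```python
# Sketch: calling the endpoints added/changed in this file.
import requests

# /query now honors only_need_prompt and picks top_k per mode (3 naive / 60 graph).
r = requests.post(
    "http://127.0.0.1:8020/query",
    json={"query": "What is this document about?", "mode": "hybrid",
          "only_need_prompt": True},
)
print(r.json())

# /data exposes the stored graph: type is one of "nodes", "edges", "statistics".
stats = requests.get("http://127.0.0.1:8020/data", params={"type": "statistics"})
print(stats.json())
```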

examples/lightrag_oracle_demo.py

@@ -84,6 +84,7 @@ async def main():
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
+    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,

lightrag/base.py

@@ -17,9 +17,12 @@ T = TypeVar("T")
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
+    only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
+    # Number of document chunks to retrieve.
+    # top_n: int = 10
     # Number of tokens for the original chunks.
     max_token_for_text_unit: int = 4000
     # Number of tokens for the relationship descriptions
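The new `only_need_prompt` flag completes a three-level escalation: raw retrieved context, the fully rendered system prompt, or the final LLM answer. A sketch (the working directory is hypothetical):

```python
# Sketch: the three output levels QueryParam now supports.
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./lightrag_cache")  # hypothetical cache dir
q = "How do entities relate to their sources?"

ctx = rag.query(q, param=QueryParam(mode="local", only_need_context=True))
prompt = rag.query(q, param=QueryParam(mode="local", only_need_prompt=True))  # context already embedded in the system prompt
answer = rag.query(q, param=QueryParam(mode="local"))                         # full LLM response
```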

lightrag/kg/oracle_impl.py

@@ -545,6 +545,13 @@ class OracleGraphStorage(BaseGraphStorage):
         if res:
             return res

+    async def get_statistics(self):
+        SQL = SQL_TEMPLATES["get_statistics"]
+        params = {"workspace": self.db.workspace}
+        res = await self.db.query(sql=SQL, params=params, multirows=True)
+        if res:
+            return res

 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
@@ -717,13 +724,22 @@ SQL_TEMPLATES = {
     WHEN NOT MATCHED THEN
     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
     values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
-    "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
-        FROM LIGHTRAG_GRAPH_NODES t1
-        LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
-        WHERE t1.workspace=:workspace
-        order by t1.CREATETIME DESC
-        fetch first :limit rows only
-        """,
+    "get_all_nodes": """WITH t0 AS (
+        SELECT name AS id, entity_type AS label, entity_type, description,
+            '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
+        FROM lightrag_graph_nodes
+        WHERE workspace = :workspace
+        ORDER BY createtime DESC fetch first :limit rows only
+    ), t1 AS (
+        SELECT t0.id, source_chunk_id
+        FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
+    ), t2 AS (
+        SELECT t1.id, LISTAGG(t2.content, '\n') content
+        FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
+        GROUP BY t1.id
+    )
+    SELECT t0.id, label, entity_type, description, t2.content
+    FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
t1.weight,t1.DESCRIPTION,t2.content t1.weight,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_EDGES t1 FROM LIGHTRAG_GRAPH_EDGES t1
@@ -731,4 +747,13 @@ SQL_TEMPLATES = {
     WHERE t1.workspace=:workspace
     order by t1.CREATETIME DESC
     fetch first :limit rows only""",
+    "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
+        count(distinct CASE WHEN type='edge' THEN id END) as edges_count
+    FROM (
+        select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
+            MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
+        UNION
+        select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
+            MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
+    )""",
} }
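The new `get_statistics` pair (method plus SQL template) is what backs the `/data?type=statistics` endpoint in the API demo above. A hedged usage sketch, outside the diff:

```python
# Sketch: calling the new statistics helper directly, assuming `rag` was
# initialized with the Oracle storages as in the demo earlier in this PR.
import asyncio

async def show_stats(rag):
    stats = await rag.chunk_entity_relation_graph.get_statistics()
    print(stats)  # e.g. [{"NODES_COUNT": 123, "EDGES_COUNT": 456}] (illustrative)

# asyncio.run(show_stats(rag))
```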

lightrag/lightrag.py

@@ -13,9 +13,8 @@ from .llm import (
from .operate import ( from .operate import (
chunking_by_token_size, chunking_by_token_size,
extract_entities, extract_entities,
local_query, # local_query,global_query,hybrid_query,
global_query, kg_query,
hybrid_query,
naive_query, naive_query,
) )
@@ -415,28 +414,8 @@ class LightRAG:
         return loop.run_until_complete(self.aquery(query, param))

     async def aquery(self, query: str, param: QueryParam = QueryParam()):
-        if param.mode == "local":
-            response = await local_query(
-                query,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-            )
-        elif param.mode == "global":
-            response = await global_query(
-                query,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-            )
-        elif param.mode == "hybrid":
-            response = await hybrid_query(
+        if param.mode in ["local", "global", "hybrid"]:
+            response = await kg_query(
                 query,
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
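With the dispatch collapsed, one code path now serves all three graph modes; only `naive` keeps its own function. For example (hypothetical working directory, not from the diff):

```python
# All graph-based modes now route through kg_query; "naive" stays separate.
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./lightrag_cache")  # hypothetical dir
for mode in ["naive", "local", "global", "hybrid"]:
    print(mode, "->", rag.query("Summarize the corpus.", param=QueryParam(mode=mode)))
```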

lightrag/llm.py

@@ -69,12 +69,15 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content
+    if r"\u" in content:
+        content = content.encode("utf-8").decode("unicode_escape")
+    # print(content)
     if hashing_kv is not None:
         await hashing_kv.upsert(
             {args_hash: {"return": response.choices[0].message.content, "model": model}}
         )
-    return response.choices[0].message.content
+    return content
@retry( @retry(
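The added unicode-escape pass turns literal `\uXXXX` sequences occasionally emitted by models back into real characters. A quick illustration of the exact transformation (outside the diff; note it is a heuristic that only fires when a literal `\u` appears in the text):

```python
# Illustration of the decode step above.
s = "\\u4f60\\u597d"  # literal backslash-u sequences, not real characters
print(s.encode("utf-8").decode("unicode_escape"))  # -> 你好
```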

lightrag/operate.py

@@ -249,6 +249,17 @@ async def extract_entities(
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

     ordered_chunks = list(chunks.items())
+    # add language and example number params to prompt
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
+        examples = "\n".join(
+            PROMPTS["entity_extraction_examples"][: int(example_number)]
+        )
+    else:
+        examples = "\n".join(PROMPTS["entity_extraction_examples"])

     entity_extract_prompt = PROMPTS["entity_extraction"]
     context_base = dict(
@@ -256,7 +267,10 @@ async def extract_entities(
         record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
         completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
         entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
+        examples=examples,
+        language=language,
     )

     continue_prompt = PROMPTS["entiti_continue_extraction"]
     if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
@@ -271,7 +285,6 @@ async def extract_entities(
         content = chunk_dp["content"]
         hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
         final_result = await use_llm_func(hint_prompt)
-
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
         for now_glean_index in range(entity_extract_max_gleaning):
             glean_result = await use_llm_func(continue_prompt, history_messages=history)
@@ -414,7 +427,7 @@ async def extract_entities(
     return knowledge_graph_inst


-async def local_query(
+async def kg_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
@@ -424,42 +437,63 @@ async def local_query(
     global_config: dict,
 ) -> str:
     context = None
-    use_model_func = global_config["llm_model_func"]
-
-    kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
-    result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    try:
-        keywords_data = json.loads(json_text)
-        keywords = keywords_data.get("low_level_keywords", [])
-        keywords = ", ".join(keywords)
-    except json.JSONDecodeError:
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[-1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            keywords = keywords_data.get("low_level_keywords", [])
-            keywords = ", ".join(keywords)
-        # Handle parsing error
-        except json.JSONDecodeError as e:
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-    if keywords:
-        context = await _build_local_query_context(
-            keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            text_chunks_db,
-            query_param,
-        )
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
+        examples = "\n".join(
+            PROMPTS["keywords_extraction_examples"][: int(example_number)]
+        )
+    else:
+        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+
+    # Set mode
+    if query_param.mode not in ["local", "global", "hybrid"]:
+        logger.error(f"Unknown mode {query_param.mode} in kg_query")
+        return PROMPTS["fail_response"]
+
+    # LLM generate keywords
+    use_model_func = global_config["llm_model_func"]
+    kw_prompt_temp = PROMPTS["keywords_extraction"]
+    kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
+    result = await use_model_func(kw_prompt)
+    logger.info("kw_prompt result:")
+    print(result)
+    try:
+        json_text = locate_json_string_body_from_string(result)
+        keywords_data = json.loads(json_text)
+        hl_keywords = keywords_data.get("high_level_keywords", [])
+        ll_keywords = keywords_data.get("low_level_keywords", [])
+    # Handle parsing error
+    except json.JSONDecodeError as e:
+        print(f"JSON parsing error: {e} {result}")
+        return PROMPTS["fail_response"]
+
+    # Handle keywords missing
+    if hl_keywords == [] and ll_keywords == []:
+        logger.warning("low_level_keywords and high_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
+        logger.warning("low_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    else:
+        ll_keywords = ", ".join(ll_keywords)
+    if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
+        logger.warning("high_level_keywords is empty")
+        return PROMPTS["fail_response"]
+    else:
+        hl_keywords = ", ".join(hl_keywords)
+
+    # Build context
+    keywords = [ll_keywords, hl_keywords]
+    context = await _build_query_context(
+        keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        text_chunks_db,
+        query_param,
+    )
+
     if query_param.only_need_context:
         return context
     if context is None:
@@ -468,6 +502,8 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
     )
@@ -486,22 +522,114 @@ async def local_query(
     return response


-async def _build_local_query_context(
+async def _build_query_context(
+    query: list,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+):
+    ll_keywords, hl_keywords = query[0], query[1]
+    if query_param.mode in ["local", "hybrid"]:
+        if ll_keywords == "":
+            ll_entities_context, ll_relations_context, ll_text_units_context = (
+                "",
+                "",
+                "",
+            )
+            warnings.warn(
+                "Low Level context is None. Return empty Low entity/relationship/source"
+            )
+            query_param.mode = "global"
+        else:
+            (
+                ll_entities_context,
+                ll_relations_context,
+                ll_text_units_context,
+            ) = await _get_node_data(
+                ll_keywords,
+                knowledge_graph_inst,
+                entities_vdb,
+                text_chunks_db,
+                query_param,
+            )
+    if query_param.mode in ["global", "hybrid"]:
+        if hl_keywords == "":
+            hl_entities_context, hl_relations_context, hl_text_units_context = (
+                "",
+                "",
+                "",
+            )
+            warnings.warn(
+                "High Level context is None. Return empty High entity/relationship/source"
+            )
+            query_param.mode = "local"
+        else:
+            (
+                hl_entities_context,
+                hl_relations_context,
+                hl_text_units_context,
+            ) = await _get_edge_data(
+                hl_keywords,
+                knowledge_graph_inst,
+                relationships_vdb,
+                text_chunks_db,
+                query_param,
+            )
+    if query_param.mode == "hybrid":
+        entities_context, relations_context, text_units_context = combine_contexts(
+            [hl_entities_context, ll_entities_context],
+            [hl_relations_context, ll_relations_context],
+            [hl_text_units_context, ll_text_units_context],
+        )
+    elif query_param.mode == "local":
+        entities_context, relations_context, text_units_context = (
+            ll_entities_context,
+            ll_relations_context,
+            ll_text_units_context,
+        )
+    elif query_param.mode == "global":
+        entities_context, relations_context, text_units_context = (
+            hl_entities_context,
+            hl_relations_context,
+            hl_text_units_context,
+        )
+    return f"""
+-----Entities-----
+```csv
+{entities_context}
+```
+-----Relationships-----
+```csv
+{relations_context}
+```
+-----Sources-----
+```csv
+{text_units_context}
+```
+"""


+async def _get_node_data(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
+    # get similar entities
     results = await entities_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return None
+    # get entity information
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
+    # get entity degree
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -510,15 +638,19 @@ async def _build_local_query_context(
         for k, n, d in zip(results, node_datas, node_degrees)
         if n is not None
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
+    # get entity text chunks
     use_text_units = await _find_most_related_text_unit_from_entities(
         node_datas, query_param, text_chunks_db, knowledge_graph_inst
     )
+    # get related edges
     use_relations = await _find_most_related_edges_from_entities(
         node_datas, query_param, knowledge_graph_inst
     )
     logger.info(
         f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
     )
+    # build prompt
     entites_section_list = [["id", "entity", "type", "description", "rank"]]
     for i, n in enumerate(node_datas):
         entites_section_list.append(
@@ -553,20 +685,7 @@ async def _build_local_query_context(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-    return f"""
------Entities-----
-```csv
-{entities_context}
-```
------Relationships-----
-```csv
-{relations_context}
-```
------Sources-----
-```csv
-{text_units_context}
-```
-"""
+    return entities_context, relations_context, text_units_context


 async def _find_most_related_text_unit_from_entities(
@@ -683,86 +802,9 @@ async def _find_most_related_edges_from_entities(
     return all_edges_data


-async def global_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-) -> str:
-    context = None
-    use_model_func = global_config["llm_model_func"]
-
-    kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
-    result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    try:
-        keywords_data = json.loads(json_text)
-        keywords = keywords_data.get("high_level_keywords", [])
-        keywords = ", ".join(keywords)
-    except json.JSONDecodeError:
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[-1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            keywords = keywords_data.get("high_level_keywords", [])
-            keywords = ", ".join(keywords)
-        except json.JSONDecodeError as e:
-            # Handle parsing error
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-
-    if keywords:
-        context = await _build_global_query_context(
-            keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
-
-    if query_param.only_need_context:
-        return context
-    if context is None:
-        return PROMPTS["fail_response"]
-
-    sys_prompt_temp = PROMPTS["rag_response"]
-    sys_prompt = sys_prompt_temp.format(
-        context_data=context, response_type=query_param.response_type
-    )
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-    )
-    if len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-
-    return response
-
-
-async def _build_global_query_context(
+async def _get_edge_data(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
@@ -804,6 +846,7 @@ async def _build_global_query_context(
     logger.info(
         f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
     )
+
     relations_section_list = [
         ["id", "source", "target", "description", "keywords", "weight", "rank"]
     ]
@@ -838,21 +881,7 @@ async def _build_global_query_context(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-
-    return f"""
------Entities-----
-```csv
-{entities_context}
-```
------Relationships-----
-```csv
-{relations_context}
-```
------Sources-----
-```csv
-{text_units_context}
-```
-"""
+    return entities_context, relations_context, text_units_context


 async def _find_most_related_entities_from_relationships(
@@ -929,134 +958,11 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units


-async def hybrid_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-) -> str:
-    low_level_context = None
-    high_level_context = None
-    use_model_func = global_config["llm_model_func"]
-
-    kw_prompt_temp = PROMPTS["keywords_extraction"]
-    kw_prompt = kw_prompt_temp.format(query=query)
-
-    result = await use_model_func(kw_prompt)
-    json_text = locate_json_string_body_from_string(result)
-    try:
-        keywords_data = json.loads(json_text)
-        hl_keywords = keywords_data.get("high_level_keywords", [])
-        ll_keywords = keywords_data.get("low_level_keywords", [])
-        hl_keywords = ", ".join(hl_keywords)
-        ll_keywords = ", ".join(ll_keywords)
-    except json.JSONDecodeError:
-        try:
-            result = (
-                result.replace(kw_prompt[:-1], "")
-                .replace("user", "")
-                .replace("model", "")
-                .strip()
-            )
-            result = "{" + result.split("{")[-1].split("}")[0] + "}"
-            keywords_data = json.loads(result)
-            hl_keywords = keywords_data.get("high_level_keywords", [])
-            ll_keywords = keywords_data.get("low_level_keywords", [])
-            hl_keywords = ", ".join(hl_keywords)
-            ll_keywords = ", ".join(ll_keywords)
-        # Handle parsing error
-        except json.JSONDecodeError as e:
-            print(f"JSON parsing error: {e}")
-            return PROMPTS["fail_response"]
-
-    if ll_keywords:
-        low_level_context = await _build_local_query_context(
-            ll_keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            text_chunks_db,
-            query_param,
-        )
-
-    if hl_keywords:
-        high_level_context = await _build_global_query_context(
-            hl_keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
-
-    context = combine_contexts(high_level_context, low_level_context)
-
-    if query_param.only_need_context:
-        return context
-    if context is None:
-        return PROMPTS["fail_response"]
-
-    sys_prompt_temp = PROMPTS["rag_response"]
-    sys_prompt = sys_prompt_temp.format(
-        context_data=context, response_type=query_param.response_type
-    )
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-    )
-    if len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-    return response
-
-
-def combine_contexts(high_level_context, low_level_context):
+def combine_contexts(entities, relationships, sources):
     # Function to extract entities, relationships, and sources from context strings
-
-    def extract_sections(context):
-        entities_match = re.search(
-            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-        relationships_match = re.search(
-            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-        sources_match = re.search(
-            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
-        )
-
-        entities = entities_match.group(1) if entities_match else ""
-        relationships = relationships_match.group(1) if relationships_match else ""
-        sources = sources_match.group(1) if sources_match else ""
-
-        return entities, relationships, sources
-
-    # Extract sections from both contexts
-    if high_level_context is None:
-        warnings.warn(
-            "High Level context is None. Return empty High entity/relationship/source"
-        )
-        hl_entities, hl_relationships, hl_sources = "", "", ""
-    else:
-        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
-
-    if low_level_context is None:
-        warnings.warn(
-            "Low Level context is None. Return empty Low entity/relationship/source"
-        )
-        ll_entities, ll_relationships, ll_sources = "", "", ""
-    else:
-        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+    hl_entities, ll_entities = entities[0], entities[1]
+    hl_relationships, ll_relationships = relationships[0], relationships[1]
+    hl_sources, ll_sources = sources[0], sources[1]

     # Combine and deduplicate the entities
     combined_entities = process_combine_contexts(hl_entities, ll_entities)
@@ -1068,21 +974,7 @@ def combine_contexts(high_level_context, low_level_context):
     # Combine and deduplicate the sources
     combined_sources = process_combine_contexts(hl_sources, ll_sources)

-    # Format the combined context
-    return f"""
------Entities-----
-```csv
-{combined_entities}
-```
------Relationships-----
-```csv
-{combined_relationships}
-```
------Sources-----
-```csv
-{combined_sources}
-```
-"""
+    return combined_entities, combined_relationships, combined_sources


 async def naive_query(
@@ -1105,13 +997,15 @@ async def naive_query(
         max_token_size=query_param.max_token_for_text_unit,
     )
     logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
-    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
+    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
     if query_param.only_need_context:
         return section
     sys_prompt_temp = PROMPTS["naive_rag_response"]
     sys_prompt = sys_prompt_temp.format(
         content_data=section, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
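Taken together, the changes in this file make prompting configurable at init time: `addon_params` now controls both how many few-shot examples are rendered into the extraction and keyword prompts and the requested output language. A hedged sketch (hypothetical working directory):

```python
# Sketch: the addon_params keys wired up in this commit.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./lightrag_cache",  # hypothetical
    addon_params={
        "example_number": 1,               # render only the first few-shot example
        "language": "Simplified Chinese",  # falls back to PROMPTS["DEFAULT_LANGUAGE"]
    },
)
```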

lightrag/prompt.py

@@ -2,6 +2,7 @@ GRAPH_FIELD_SEP = "<SEP>"
 PROMPTS = {}

+PROMPTS["DEFAULT_LANGUAGE"] = "English"
 PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
 PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
 PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
@@ -11,6 +12,7 @@ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS["entity_extraction"] = """-Goal- PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.
-Steps- -Steps-
1. Identify all entities. For each identified entity, extract the following information: 1. Identify all entities. For each identified entity, extract the following information:
@@ -38,7 +40,19 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_
 ######################
 -Examples-
 ######################
-Example 1:
+{examples}
+
+#############################
+-Real Data-
+######################
+Entity_types: {entity_types}
+Text: {input_text}
+######################
+Output:
+"""
+
+PROMPTS["entity_extraction_examples"] = [
+    """Example 1:

 Entity_types: [person, technology, mission, organization, location]
 Text:
@@ -62,8 +76,8 @@ Output:
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
############################# #############################""",
Example 2: """Example 2:
Entity_types: [person, technology, mission, organization, location] Entity_types: [person, technology, mission, organization, location]
Text: Text:
@@ -80,8 +94,8 @@ Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter} ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter} ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter} ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
############################# #############################""",
Example 3: """Example 3:
Entity_types: [person, role, technology, organization, event, location, concept] Entity_types: [person, role, technology, organization, event, location, concept]
Text: Text:
@@ -107,14 +121,8 @@ Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter} ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
############################# #############################""",
-Real Data- ]
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
PROMPTS[ PROMPTS[
"summarize_entity_descriptions" "summarize_entity_descriptions"
@@ -123,6 +131,7 @@ Given one or two entities, and a list of descriptions, all related to the same e
 Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
 If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
 Make sure it is written in third person, and include the entity names so we have the full context.
+Use Chinese as output language.

 #######
 -Data-
@@ -169,6 +178,7 @@ Add sections and commentary to the response as appropriate for the length and fo
PROMPTS["keywords_extraction"] = """---Role--- PROMPTS["keywords_extraction"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query. You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
Use Chinese as output language.
---Goal--- ---Goal---
@@ -184,7 +194,20 @@ Given the query, list both high-level and low-level keywords. High-level keyword
 ######################
 -Examples-
 ######################
-Example 1:
+{examples}
+
+#############################
+-Real Data-
+######################
+Query: {query}
+######################
+The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
+Output:
+"""
+
+PROMPTS["keywords_extraction_examples"] = [
+    """Example 1:

 Query: "How does international trade influence global economic stability?"
 ################
@@ -193,8 +216,8 @@ Output:
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"], "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}} }}
############################# #############################""",
Example 2: """Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?" Query: "What are the environmental consequences of deforestation on biodiversity?"
################ ################
@@ -203,8 +226,8 @@ Output:
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}} }}
############################# #############################""",
Example 3: """Example 3:
Query: "What is the role of education in reducing poverty?" Query: "What is the role of education in reducing poverty?"
################ ################
@@ -213,14 +236,9 @@ Output:
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}} }}
############################# #############################""",
-Real Data- ]
######################
Query: {query}
######################
Output:
"""
PROMPTS["naive_rag_response"] = """---Role--- PROMPTS["naive_rag_response"] = """---Role---

lightrag/utils.py

@@ -47,10 +47,27 @@ class EmbeddingFunc:
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
-    maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
-    if maybe_json_str is not None:
-        return maybe_json_str.group(0)
-    else:
+    try:
+        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
+        if maybe_json_str is not None:
+            maybe_json_str = maybe_json_str.group(0)
+            maybe_json_str = maybe_json_str.replace("\\n", "")
+            maybe_json_str = maybe_json_str.replace("\n", "")
+            maybe_json_str = maybe_json_str.replace("'", '"')
+            json.loads(maybe_json_str)
+            return maybe_json_str
+    except Exception:
+        pass
+    # try:
+    #     content = (
+    #         content.replace(kw_prompt[:-1], "")
+    #         .replace("user", "")
+    #         .replace("model", "")
+    #         .strip()
+    #     )
+    #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
+    #     json.loads(maybe_json_str)
     return None
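A quick illustration of what the hardened locator now tolerates — surrounding prose, stray newlines, and single-quoted keys (note the single-quote swap is a heuristic and would mangle apostrophes inside values):

```python
# Illustration, outside the diff: the locator extracts and normalizes the JSON body.
from lightrag.utils import locate_json_string_body_from_string

raw = "Sure! Keywords below:\n{'high_level_keywords': ['Trade'],\n 'low_level_keywords': ['Tariffs']}"
print(locate_json_string_body_from_string(raw))
# {"high_level_keywords": ["Trade"], "low_level_keywords": ["Tariffs"]}
```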