Logic Optimization

jin
2024-11-25 13:40:38 +08:00
parent bf5815be8f
commit 21f161390a
8 changed files with 185 additions and 136 deletions

.gitignore

@@ -13,4 +13,4 @@ ignore_this.txt
*.ignore.*
.ruff_cache/
gui/
*.log
*.log


@@ -1,16 +1,14 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional,Any
from fastapi.responses import JSONResponse
from typing import Optional, Any
import sys
import os
import sys, os
print(os.getcwd())
from pathlib import Path
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
import asyncio
import nest_asyncio
@@ -18,10 +16,12 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
# Apply nest_asyncio to solve event loop issues
@@ -47,7 +47,8 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -77,10 +78,10 @@ async def get_embedding_dim():
embedding_dim = embedding.shape[1]
return embedding_dim
async def init():
# Detect embedding dimension
embedding_dimension = 1024 #await get_embedding_dim()
embedding_dimension = 1024 # await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
@@ -88,36 +89,36 @@ async def init():
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "path_to_config_dir",
"wallet_location": "path_to_wallet_location",
"wallet_password": "wallet_password",
"workspace": "company",
} # specify which docs you want to store and query
)
oracle_db = OracleDB(config={
"user":"",
"password":"",
"dsn":"",
"config_dir":"path_to_config_dir",
"wallet_location":"path_to_wallet_location",
"wallet_password":"wallet_password",
"workspace":"company"
} # specify which docs you want to store and query
)
# Check if Oracle DB tables exist, if not, tables will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
)
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.graph_storage_cls.db = oracle_db
@@ -128,7 +129,7 @@ async def init():
# Extract and Insert into LightRAG storage
#with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# await rag.ainsert(f.read())
# # Perform search in different modes
@@ -147,9 +148,11 @@ class QueryRequest(BaseModel):
only_need_context: bool = False
only_need_prompt: bool = False
class DataRequest(BaseModel):
limit: int = 100
class InsertRequest(BaseModel):
text: str
@@ -164,6 +167,7 @@ class Response(BaseModel):
rag = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
@@ -172,25 +176,28 @@ async def lifespan(app: FastAPI):
yield
app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
#try:
# loop = asyncio.get_event_loop()
# try:
# loop = asyncio.get_event_loop()
if request.mode == "naive":
top_k = 3
else:
top_k = 60
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k
),
)
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k,
),
)
return Response(status="success", data=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
@@ -199,9 +206,9 @@ async def query_endpoint(request: QueryRequest):
@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
if type == "nodes":
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
elif type == "edges":
result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
elif type == "statistics":
result = await rag.chunk_entity_relation_graph.get_statistics()
return Response(status="success", data=result)
@@ -264,4 +271,4 @@ if __name__ == "__main__":
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"
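The /data endpoint shown above has no entry in the file's curl list; a sketch in the same style, assuming the same host and port as the examples above (the type values map to the endpoint's branches):
# 5. Query stored graph data (type is one of nodes / edges / statistics):
# curl -X GET "http://127.0.0.1:8020/data?type=nodes&limit=50"
# curl -X GET "http://127.0.0.1:8020/data?type=statistics"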


@@ -97,8 +97,7 @@ async def main():
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params = {"example_number":1, "language":"Simplfied Chinese"},
addon_params={"example_number": 1, "language": "Simplfied Chinese"},
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool


@@ -114,7 +114,9 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
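The new params argument lets callers pass bind variables instead of interpolating values into the SQL string. A minimal call-site sketch, assuming the standard oracledb named-bind (:name) syntax; the table name and workspace value are illustrative:
# Hypothetical call site for the new signature; values are bound, not interpolated
row = await oracle_db.query(
    "SELECT count(*) AS cnt FROM LIGHTRAG_DOC_FULL WHERE workspace = :workspace",
    params={"workspace": "company"},
)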
@@ -256,7 +258,7 @@ class OracleKVStorage(BaseKVStorage):
item["__vector__"],
]
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, values)
if self.namespace == "full_docs":
for k, v in self._data.items():
@@ -266,7 +268,7 @@ class OracleKVStorage(BaseKVStorage):
)
values = [k, self._data[k]["content"], self.db.workspace]
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, values)
return left_data
async def index_done_callback(self):


@@ -70,8 +70,8 @@ async def openai_complete_if_cache(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
if r'\u' in content:
content = content.encode('utf-8').decode('unicode_escape')
if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape")
print(content)
if hashing_kv is not None:
await hashing_kv.upsert(
@@ -542,7 +542,7 @@ async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@@ -551,7 +551,7 @@ async def openai_embedding(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
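For reference, a minimal sketch of calling openai_embedding with the parameters touched above; the import matches the one used earlier in this commit, and the api_key is a placeholder:
import asyncio
from lightrag.llm import openai_embedding

async def embed_demo():
    # api_key is a placeholder; when omitted, the OPENAI_API_KEY env var is used
    vecs = await openai_embedding(
        ["hello world", "lightrag"],
        model="text-embedding-3-small",
        api_key="sk-placeholder",
    )
    print(vecs.shape)  # (2, embedding_dim)

asyncio.run(embed_demo())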


@@ -249,13 +249,17 @@ async def extract_entities(
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number<len(PROMPTS["entity_extraction_examples"]):
examples="\n".join(PROMPTS["entity_extraction_examples"][:int(example_number)])
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
examples = "\n".join(
PROMPTS["entity_extraction_examples"][: int(example_number)]
)
else:
examples="\n".join(PROMPTS["entity_extraction_examples"])
examples = "\n".join(PROMPTS["entity_extraction_examples"])
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
@@ -263,8 +267,9 @@ async def extract_entities(
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
examples=examples,
language=language)
language=language,
)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
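Both the language and the number of few-shot examples above are driven by addon_params. A minimal configuration sketch, reusing names from the Oracle example earlier in this commit; the values are illustrative:
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    # read back via global_config["addon_params"] in extract_entities and kg_query
    addon_params={"example_number": 1, "language": "English"},
)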
@@ -396,6 +401,7 @@ async def extract_entities(
return knowledge_graph_inst
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -408,59 +414,61 @@ async def kg_query(
context = None
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)])
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples="\n".join(PROMPTS["keywords_extraction_examples"])
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query,examples=examples)
result = await use_model_func(kw_prompt)
logger.info(f"kw_prompt result:")
kw_prompt = kw_prompt_temp.format(query=query, examples=examples)
result = await use_model_func(kw_prompt)
logger.info("kw_prompt result:")
print(result)
try:
json_text = locate_json_string_body_from_string(result)
keywords_data = json.loads(json_text)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
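For reference, the JSON shape the try block above expects from the LLM, mirroring the keywords_extraction examples later in this commit:
# {
#     "high_level_keywords": ["International trade", "Global economic stability"],
#     "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange"]
# }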
# Handle missing keywords
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local","hybrid"]:
return PROMPTS["fail_response"]
if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty")
return PROMPTS["fail_response"]
else:
ll_keywords = ", ".join(ll_keywords)
if hl_keywords == [] and query_param.mode in ["global","hybrid"]:
if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty")
return PROMPTS["fail_response"]
else:
hl_keywords = ", ".join(hl_keywords)
# Build context
keywords = [ll_keywords, hl_keywords]
keywords = [ll_keywords, hl_keywords]
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
@@ -468,13 +476,13 @@ async def kg_query(
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -496,44 +504,72 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
):
ll_kewwords, hl_keywrds = query[0], query[1]
if query_param.mode in ["local", "hybrid"]:
if ll_kewwords == "":
ll_entities_context,ll_relations_context,ll_text_units_context = "","",""
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
ll_entities_context, ll_relations_context, ll_text_units_context = (
"",
"",
"",
)
warnings.warn(
"Low Level context is None. Return empty Low entity/relationship/source"
)
query_param.mode = "global"
else:
ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data(
(
ll_entities_context,
ll_relations_context,
ll_text_units_context,
) = await _get_node_data(
ll_kewwords,
knowledge_graph_inst,
entities_vdb,
text_chunks_db,
query_param
)
query_param,
)
if query_param.mode in ["global", "hybrid"]:
if hl_keywrds == "":
hl_entities_context,hl_relations_context,hl_text_units_context = "","",""
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities_context, hl_relations_context, hl_text_units_context = (
"",
"",
"",
)
warnings.warn(
"High Level context is None. Return empty High entity/relationship/source"
)
query_param.mode = "local"
else:
hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data(
(
hl_entities_context,
hl_relations_context,
hl_text_units_context,
) = await _get_edge_data(
hl_keywrds,
knowledge_graph_inst,
relationships_vdb,
text_chunks_db,
query_param
)
if query_param.mode == 'hybrid':
entities_context,relations_context,text_units_context = combine_contexts(
[hl_entities_context,ll_entities_context],
[hl_relations_context,ll_relations_context],
[hl_text_units_context,ll_text_units_context]
)
elif query_param.mode == 'local':
entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context
elif query_param.mode == 'global':
entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context
query_param,
)
if query_param.mode == "hybrid":
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context],
[hl_text_units_context, ll_text_units_context],
)
elif query_param.mode == "local":
entities_context, relations_context, text_units_context = (
ll_entities_context,
ll_relations_context,
ll_text_units_context,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = (
hl_entities_context,
hl_relations_context,
hl_text_units_context,
)
return f"""
# -----Entities-----
# ```csv
@@ -550,7 +586,6 @@ async def _build_query_context(
# """
async def _get_node_data(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -568,7 +603,7 @@ async def _get_node_data(
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
# Get the degree of each entity
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
@@ -588,7 +623,7 @@ async def _get_node_data(
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
)
)
# Build the prompt
entites_section_list = [["id", "entity", "type", "description", "rank"]]
@@ -625,7 +660,7 @@ async def _get_node_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities(
@@ -821,8 +856,7 @@ async def _get_edge_data(
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context,relations_context,text_units_context
return entities_context, relations_context, text_units_context
async def _find_most_related_entities_from_relationships(
@@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships(
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
hl_entities, ll_entities = entities[0], entities[1]
hl_relationships, ll_relationships = relationships[0],relationships[1]
hl_relationships, ll_relationships = relationships[0], relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)


@@ -52,7 +52,7 @@ Output:
"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
"""Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -77,7 +77,7 @@ Output:
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
#############################""",
"""Example 2:
"""Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
@@ -95,7 +95,7 @@ Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
#############################""",
"""Example 3:
"""Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
@@ -121,10 +121,12 @@ Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
#############################"""
#############################""",
]
PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
PROMPTS[
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
@@ -139,10 +141,14 @@ Description List: {description_list}
Output:
"""
PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format:
PROMPTS[
"entiti_continue_extraction"
] = """MANY entities were missed in the last extraction. Add them below using the same format:
"""
PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
PROMPTS[
"entiti_if_loop_extraction"
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
"""
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
@@ -201,7 +207,7 @@ Output:
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
"""Example 1:
Query: "How does international trade influence global economic stability?"
################
@@ -211,7 +217,7 @@ Output:
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}}
#############################""",
"""Example 2:
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
################
@@ -220,8 +226,8 @@ Output:
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}}
#############################""",
"""Example 3:
#############################""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
################
@@ -230,8 +236,8 @@ Output:
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}}
#############################"""
]
#############################""",
]
PROMPTS["naive_rag_response"] = """---Role---


@@ -56,7 +56,8 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
maybe_json_str = maybe_json_str.replace("'", '"')
json.loads(maybe_json_str)
return maybe_json_str
except:
except Exception:
pass
# try:
# content = (
# content.replace(kw_prompt[:-1], "")
@@ -64,9 +65,9 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
# .replace("model", "")
# .strip()
# )
# maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
# maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
# json.loads(maybe_json_str)
return None
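A minimal usage sketch for this helper, matching how kg_query uses it above; the chatter around the JSON body is illustrative LLM output:
import json

raw = 'Keywords: {"high_level_keywords": ["Education"], "low_level_keywords": ["Literacy rates"]}'
json_text = locate_json_string_body_from_string(raw)
if json_text is not None:
    print(json.loads(json_text)["high_level_keywords"])  # ['Education']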