diff --git a/.gitignore b/.gitignore
index 942c2c25..e6f5f5ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,5 @@ ignore_this.txt
 .venv/
 *.ignore.*
 .ruff_cache/
+gui/
+*.log
diff --git a/README.md b/README.md
index ce14e3bb..893969f9 100644
--- a/README.md
+++ b/README.md
@@ -555,6 +555,35 @@ if __name__ == "__main__":
+### LightRAG init parameters
+
+| **Parameter** | **Type** | **Explanation** | **Default** |
+| --- | --- | --- | --- |
+| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
+| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
+| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
+| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
+| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
+| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
+| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
+| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
+| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
+| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
+| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
+| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
+| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
+| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
+| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
+| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
+| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
+| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
+| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
+| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
+| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
+| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
+| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
+| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
+
 ## API Server Implementation

 LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations.
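To make the init-parameter table above concrete, here is a minimal initialization sketch; the working directory, the sample document, and the `addon_params` values are illustrative, and every parameter left out falls back to the defaults listed in the table.

```python
import logging

from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

# Illustrative configuration; omitted parameters use the documented defaults.
rag = LightRAG(
    working_dir="./lightrag_cache_demo",    # cache directory (assumed path)
    kv_storage="JsonKVStorage",             # KV store for documents and chunks
    vector_storage="NanoVectorDBStorage",   # store for embedding vectors
    graph_storage="NetworkXStorage",        # store for graph nodes and edges
    log_level=logging.INFO,
    chunk_token_size=1200,
    chunk_overlap_token_size=100,
    llm_model_func=gpt_4o_mini_complete,
    enable_llm_cache=True,
    addon_params={"example_number": 1, "language": "Simplified Chinese"},
)

rag.insert("LightRAG couples a knowledge graph with vector retrieval.")  # sample text
print(rag.query("What is LightRAG?", param=QueryParam(mode="hybrid")))
```

The three storage parameters take the storage class name as a string, so moving to the Oracle backends used in the demos below is largely a matter of changing those strings and attaching the database connection, as the `lightrag_api_oracle_demo` changes in this diff illustrate.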
This allows you to run LightRAG as a service and interact with it through HTTP requests. diff --git a/examples/insert_custom_kg.py b/examples/insert_custom_kg.py index bbabe6a9..19da0f29 100644 --- a/examples/insert_custom_kg.py +++ b/examples/insert_custom_kg.py @@ -1,5 +1,5 @@ import os -from lightrag import LightRAG, QueryParam +from lightrag import LightRAG from lightrag.llm import gpt_4o_mini_complete ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() @@ -24,50 +24,50 @@ custom_kg = { "entity_name": "CompanyA", "entity_type": "Organization", "description": "A major technology company", - "source_id": "Source1" + "source_id": "Source1", }, { "entity_name": "ProductX", "entity_type": "Product", "description": "A popular product developed by CompanyA", - "source_id": "Source1" + "source_id": "Source1", }, { "entity_name": "PersonA", "entity_type": "Person", "description": "A renowned researcher in AI", - "source_id": "Source2" + "source_id": "Source2", }, { "entity_name": "UniversityB", "entity_type": "Organization", "description": "A leading university specializing in technology and sciences", - "source_id": "Source2" + "source_id": "Source2", }, { "entity_name": "CityC", "entity_type": "Location", "description": "A large metropolitan city known for its culture and economy", - "source_id": "Source3" + "source_id": "Source3", }, { "entity_name": "EventY", "entity_type": "Event", "description": "An annual technology conference held in CityC", - "source_id": "Source3" + "source_id": "Source3", }, { "entity_name": "CompanyD", "entity_type": "Organization", "description": "A financial services company specializing in insurance", - "source_id": "Source4" + "source_id": "Source4", }, { "entity_name": "ServiceZ", "entity_type": "Service", "description": "An insurance product offered by CompanyD", - "source_id": "Source4" - } + "source_id": "Source4", + }, ], "relationships": [ { @@ -76,7 +76,7 @@ custom_kg = { "description": "CompanyA develops ProductX", "keywords": "develop, produce", "weight": 1.0, - "source_id": "Source1" + "source_id": "Source1", }, { "src_id": "PersonA", @@ -84,7 +84,7 @@ custom_kg = { "description": "PersonA works at UniversityB", "keywords": "employment, affiliation", "weight": 0.9, - "source_id": "Source2" + "source_id": "Source2", }, { "src_id": "CityC", @@ -92,7 +92,7 @@ custom_kg = { "description": "EventY is hosted in CityC", "keywords": "host, location", "weight": 0.8, - "source_id": "Source3" + "source_id": "Source3", }, { "src_id": "CompanyD", @@ -100,9 +100,9 @@ custom_kg = { "description": "CompanyD provides ServiceZ", "keywords": "provide, offer", "weight": 1.0, - "source_id": "Source4" - } - ] + "source_id": "Source4", + }, + ], } rag.insert_custom_kg(custom_kg) diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index 9b4e2741..774ef61f 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -1,10 +1,13 @@ from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi import Query from contextlib import asynccontextmanager from pydantic import BaseModel -from typing import Optional +from typing import Optional, Any import sys import os + + from pathlib import Path import asyncio @@ -16,9 +19,7 @@ import numpy as np from lightrag.kg.oracle_impl import OracleDB - print(os.getcwd()) - script_directory = Path(__file__).resolve().parent.parent sys.path.append(os.path.abspath(script_directory)) @@ -37,14 
+38,13 @@ APIKEY = "ocigenerativeai" # Configure working directory WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") print(f"WORKING_DIR: {WORKING_DIR}") -LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus") +LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024") print(f"LLM_MODEL: {LLM_MODEL}") EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512)) print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") - if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -94,10 +94,10 @@ async def init(): "user": "", "password": "", "dsn": "", - "config_dir": "", - "wallet_location": "", - "wallet_password": "", - "workspace": "", + "config_dir": "path_to_config_dir", + "wallet_location": "path_to_wallet_location", + "wallet_password": "wallet_password", + "workspace": "company", } # specify which docs you want to store and query ) @@ -105,6 +105,7 @@ async def init(): await oracle_db.check_tables() # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage + # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, @@ -128,6 +129,17 @@ async def init(): return rag +# Extract and Insert into LightRAG storage +# with open("./dickens/book.txt", "r", encoding="utf-8") as f: +# await rag.ainsert(f.read()) + +# # Perform search in different modes +# modes = ["naive", "local", "global", "hybrid"] +# for mode in modes: +# print("="*20, mode, "="*20) +# print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode))) +# print("-"*100, "\n") + # Data models @@ -135,6 +147,11 @@ class QueryRequest(BaseModel): query: str mode: str = "hybrid" only_need_context: bool = False + only_need_prompt: bool = False + + +class DataRequest(BaseModel): + limit: int = 100 class InsertRequest(BaseModel): @@ -143,7 +160,7 @@ class InsertRequest(BaseModel): class Response(BaseModel): status: str - data: Optional[str] = None + data: Optional[Any] = None message: Optional[str] = None @@ -167,17 +184,35 @@ app = FastAPI( @app.post("/query", response_model=Response) async def query_endpoint(request: QueryRequest): - try: - # loop = asyncio.get_event_loop() - result = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, only_need_context=request.only_need_context - ), - ) - return Response(status="success", data=result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + # try: + # loop = asyncio.get_event_loop() + if request.mode == "naive": + top_k = 3 + else: + top_k = 60 + result = await rag.aquery( + request.query, + param=QueryParam( + mode=request.mode, + only_need_context=request.only_need_context, + only_need_prompt=request.only_need_prompt, + top_k=top_k, + ), + ) + return Response(status="success", data=result) + # except Exception as e: + # raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/data", response_model=Response) +async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)): + if type == "nodes": + result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit) + elif type == "edges": + result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit) + elif type == "statistics": + result = await rag.chunk_entity_relation_graph.get_statistics() + return 
Response(status="success", data=result) @app.post("/insert", response_model=Response) @@ -220,7 +255,7 @@ async def health_check(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8020) + uvicorn.run(app, host="127.0.0.1", port=8020) # Usage example # To run the server, use the following command in your terminal: diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 365b6225..2aa47c78 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -84,6 +84,7 @@ async def main(): # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage + # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, diff --git a/lightrag/base.py b/lightrag/base.py index 46dfc800..ea84c000 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -17,9 +17,12 @@ T = TypeVar("T") class QueryParam: mode: Literal["local", "global", "hybrid", "naive"] = "global" only_need_context: bool = False + only_need_prompt: bool = False response_type: str = "Multiple Paragraphs" # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. top_k: int = 60 + # Number of document chunks to retrieve. + # top_n: int = 10 # Number of tokens for the original chunks. max_token_for_text_unit: int = 4000 # Number of tokens for the relationship descriptions diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index b46d36d8..8ed73772 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -545,6 +545,13 @@ class OracleGraphStorage(BaseGraphStorage): if res: return res + async def get_statistics(self): + SQL = SQL_TEMPLATES["get_statistics"] + params = {"workspace": self.db.workspace} + res = await self.db.query(sql=SQL, params=params, multirows=True) + if res: + return res + N_T = { "full_docs": "LIGHTRAG_DOC_FULL", @@ -717,13 +724,22 @@ SQL_TEMPLATES = { WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, - "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content - FROM LIGHTRAG_GRAPH_NODES t1 - LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id - WHERE t1.workspace=:workspace - order by t1.CREATETIME DESC - fetch first :limit rows only - """, + "get_all_nodes": """WITH t0 AS ( + SELECT name AS id, entity_type AS label, entity_type, description, + '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids + FROM lightrag_graph_nodes + WHERE workspace = :workspace + ORDER BY createtime DESC fetch first :limit rows only + ), t1 AS ( + SELECT t0.id, source_chunk_id + FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) ) + ), t2 AS ( + SELECT t1.id, LISTAGG(t2.content, '\n') content + FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id + GROUP BY t1.id + ) + SELECT t0.id, label, entity_type, description, t2.content + FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, t1.weight,t1.DESCRIPTION,t2.content FROM LIGHTRAG_GRAPH_EDGES t1 @@ -731,4 +747,13 @@ SQL_TEMPLATES = { WHERE t1.workspace=:workspace order by t1.CREATETIME DESC fetch 
first :limit rows only""", + "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, + count(distinct CASE WHEN type='edge' THEN id END) as edges_count + FROM ( + select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph + MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) + UNION + select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) + )""", } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d531f4f6..cec21b2f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -13,9 +13,8 @@ from .llm import ( from .operate import ( chunking_by_token_size, extract_entities, - local_query, - global_query, - hybrid_query, + # local_query,global_query,hybrid_query, + kg_query, naive_query, ) @@ -415,28 +414,8 @@ class LightRAG: return loop.run_until_complete(self.aquery(query, param)) async def aquery(self, query: str, param: QueryParam = QueryParam()): - if param.mode == "local": - response = await local_query( - query, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - asdict(self), - ) - elif param.mode == "global": - response = await global_query( - query, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - asdict(self), - ) - elif param.mode == "hybrid": - response = await hybrid_query( + if param.mode in ["local", "global", "hybrid"]: + response = await kg_query( query, self.chunk_entity_relation_graph, self.entities_vdb, diff --git a/lightrag/llm.py b/lightrag/llm.py index 6cc46c85..6a191a0f 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -69,12 +69,15 @@ async def openai_complete_if_cache( response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) - + content = response.choices[0].message.content + if r"\u" in content: + content = content.encode("utf-8").decode("unicode_escape") + # print(content) if hashing_kv is not None: await hashing_kv.upsert( {args_hash: {"return": response.choices[0].message.content, "model": model}} ) - return response.choices[0].message.content + return content @retry( diff --git a/lightrag/operate.py b/lightrag/operate.py index 9e4b768a..c761519f 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -249,6 +249,17 @@ async def extract_entities( entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] ordered_chunks = list(chunks.items()) + # add language and example number params to prompt + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["entity_extraction_examples"]): + examples = "\n".join( + PROMPTS["entity_extraction_examples"][: int(example_number)] + ) + else: + examples = "\n".join(PROMPTS["entity_extraction_examples"]) entity_extract_prompt = PROMPTS["entity_extraction"] context_base = dict( @@ -256,7 +267,10 @@ async def extract_entities( record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"], completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]), + examples=examples, + language=language, ) + continue_prompt = PROMPTS["entiti_continue_extraction"] if_loop_prompt = PROMPTS["entiti_if_loop_extraction"] @@ -271,7 +285,6 @@ async def extract_entities( content = chunk_dp["content"] 
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) final_result = await use_llm_func(hint_prompt) - history = pack_user_ass_to_openai_messages(hint_prompt, final_result) for now_glean_index in range(entity_extract_max_gleaning): glean_result = await use_llm_func(continue_prompt, history_messages=history) @@ -414,7 +427,7 @@ async def extract_entities( return knowledge_graph_inst -async def local_query( +async def kg_query( query, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, @@ -424,42 +437,63 @@ async def local_query( global_config: dict, ) -> str: context = None - use_model_func = global_config["llm_model_func"] - - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - - try: - keywords_data = json.loads(json_text) - keywords = keywords_data.get("low_level_keywords", []) - keywords = ", ".join(keywords) - except json.JSONDecodeError: - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[-1].split("}")[0] + "}" - - keywords_data = json.loads(result) - keywords = keywords_data.get("low_level_keywords", []) - keywords = ", ".join(keywords) - # Handle parsing error - except json.JSONDecodeError as e: - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - if keywords: - context = await _build_local_query_context( - keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): + examples = "\n".join( + PROMPTS["keywords_extraction_examples"][: int(example_number)] ) + else: + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + + # Set mode + if query_param.mode not in ["local", "global", "hybrid"]: + logger.error(f"Unknown mode {query_param.mode} in kg_query") + return PROMPTS["fail_response"] + + # LLM generate keywords + use_model_func = global_config["llm_model_func"] + kw_prompt_temp = PROMPTS["keywords_extraction"] + kw_prompt = kw_prompt_temp.format(query=query, examples=examples) + result = await use_model_func(kw_prompt) + logger.info("kw_prompt result:") + print(result) + try: + json_text = locate_json_string_body_from_string(result) + keywords_data = json.loads(json_text) + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # Handle parsing error + except json.JSONDecodeError as e: + print(f"JSON parsing error: {e} {result}") + return PROMPTS["fail_response"] + + # Handdle keywords missing + if hl_keywords == [] and ll_keywords == []: + logger.warning("low_level_keywords and high_level_keywords is empty") + return PROMPTS["fail_response"] + if ll_keywords == [] and query_param.mode in ["local", "hybrid"]: + logger.warning("low_level_keywords is empty") + return PROMPTS["fail_response"] + else: + ll_keywords = ", ".join(ll_keywords) + if hl_keywords == [] and query_param.mode in ["global", "hybrid"]: + logger.warning("high_level_keywords is empty") + return PROMPTS["fail_response"] + else: + hl_keywords = ", ".join(hl_keywords) + + # Build context + keywords = [ll_keywords, hl_keywords] + context = await _build_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + 
text_chunks_db, + query_param, + ) + if query_param.only_need_context: return context if context is None: @@ -468,6 +502,8 @@ async def local_query( sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, @@ -486,22 +522,114 @@ async def local_query( return response -async def _build_local_query_context( +async def _build_query_context( + query: list, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, +): + ll_kewwords, hl_keywrds = query[0], query[1] + if query_param.mode in ["local", "hybrid"]: + if ll_kewwords == "": + ll_entities_context, ll_relations_context, ll_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "Low Level context is None. Return empty Low entity/relationship/source" + ) + query_param.mode = "global" + else: + ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) = await _get_node_data( + ll_kewwords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ) + if query_param.mode in ["global", "hybrid"]: + if hl_keywrds == "": + hl_entities_context, hl_relations_context, hl_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "High Level context is None. Return empty High entity/relationship/source" + ) + query_param.mode = "local" + else: + ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) = await _get_edge_data( + hl_keywrds, + knowledge_graph_inst, + relationships_vdb, + text_chunks_db, + query_param, + ) + if query_param.mode == "hybrid": + entities_context, relations_context, text_units_context = combine_contexts( + [hl_entities_context, ll_entities_context], + [hl_relations_context, ll_relations_context], + [hl_text_units_context, ll_text_units_context], + ) + elif query_param.mode == "local": + entities_context, relations_context, text_units_context = ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) + elif query_param.mode == "global": + entities_context, relations_context, text_units_context = ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) + return f""" +-----Entities----- +```csv +{entities_context} +``` +-----Relationships----- +```csv +{relations_context} +``` +-----Sources----- +```csv +{text_units_context} +``` +""" + + +async def _get_node_data( query, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, ): + # get similar entities results = await entities_vdb.query(query, top_k=query_param.top_k) - if not len(results): return None + # get entity information node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") + + # get entity degree node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) @@ -510,15 +638,19 @@ async def _build_local_query_context( for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. 
+ # get entitytext chunk use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) + # get relate edges use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" ) + + # build prompt entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(node_datas): entites_section_list.append( @@ -553,20 +685,7 @@ async def _build_local_query_context( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) - return f""" ------Entities----- -```csv -{entities_context} -``` ------Relationships----- -```csv -{relations_context} -``` ------Sources----- -```csv -{text_units_context} -``` -""" + return entities_context, relations_context, text_units_context async def _find_most_related_text_unit_from_entities( @@ -683,86 +802,9 @@ async def _find_most_related_edges_from_entities( return all_edges_data -async def global_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, -) -> str: - context = None - use_model_func = global_config["llm_model_func"] - - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - - try: - keywords_data = json.loads(json_text) - keywords = keywords_data.get("high_level_keywords", []) - keywords = ", ".join(keywords) - except json.JSONDecodeError: - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[-1].split("}")[0] + "}" - - keywords_data = json.loads(result) - keywords = keywords_data.get("high_level_keywords", []) - keywords = ", ".join(keywords) - - except json.JSONDecodeError as e: - # Handle parsing error - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - if keywords: - context = await _build_global_query_context( - keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) - - if query_param.only_need_context: - return context - if context is None: - return PROMPTS["fail_response"] - - sys_prompt_temp = PROMPTS["rag_response"] - sys_prompt = sys_prompt_temp.format( - context_data=context, response_type=query_param.response_type - ) - response = await use_model_func( - query, - system_prompt=sys_prompt, - ) - if len(response) > len(sys_prompt): - response = ( - response.replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) - - return response - - -async def _build_global_query_context( +async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, @@ -804,6 +846,7 @@ async def _build_global_query_context( logger.info( f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units" ) + relations_section_list = [ 
["id", "source", "target", "description", "keywords", "weight", "rank"] ] @@ -838,21 +881,7 @@ async def _build_global_query_context( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) - - return f""" ------Entities----- -```csv -{entities_context} -``` ------Relationships----- -```csv -{relations_context} -``` ------Sources----- -```csv -{text_units_context} -``` -""" + return entities_context, relations_context, text_units_context async def _find_most_related_entities_from_relationships( @@ -929,134 +958,11 @@ async def _find_related_text_unit_from_relationships( return all_text_units -async def hybrid_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, -) -> str: - low_level_context = None - high_level_context = None - use_model_func = global_config["llm_model_func"] - - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - try: - keywords_data = json.loads(json_text) - hl_keywords = keywords_data.get("high_level_keywords", []) - ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ", ".join(hl_keywords) - ll_keywords = ", ".join(ll_keywords) - except json.JSONDecodeError: - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[-1].split("}")[0] + "}" - keywords_data = json.loads(result) - hl_keywords = keywords_data.get("high_level_keywords", []) - ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ", ".join(hl_keywords) - ll_keywords = ", ".join(ll_keywords) - # Handle parsing error - except json.JSONDecodeError as e: - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - - if ll_keywords: - low_level_context = await _build_local_query_context( - ll_keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) - - if hl_keywords: - high_level_context = await _build_global_query_context( - hl_keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) - - context = combine_contexts(high_level_context, low_level_context) - - if query_param.only_need_context: - return context - if context is None: - return PROMPTS["fail_response"] - - sys_prompt_temp = PROMPTS["rag_response"] - sys_prompt = sys_prompt_temp.format( - context_data=context, response_type=query_param.response_type - ) - response = await use_model_func( - query, - system_prompt=sys_prompt, - ) - if len(response) > len(sys_prompt): - response = ( - response.replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) - return response - - -def combine_contexts(high_level_context, low_level_context): +def combine_contexts(entities, relationships, sources): # Function to extract entities, relationships, and sources from context strings - - def extract_sections(context): - entities_match = re.search( - r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - relationships_match = re.search( - r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - sources_match 
= re.search( - r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - - entities = entities_match.group(1) if entities_match else "" - relationships = relationships_match.group(1) if relationships_match else "" - sources = sources_match.group(1) if sources_match else "" - - return entities, relationships, sources - - # Extract sections from both contexts - - if high_level_context is None: - warnings.warn( - "High Level context is None. Return empty High entity/relationship/source" - ) - hl_entities, hl_relationships, hl_sources = "", "", "" - else: - hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) - - if low_level_context is None: - warnings.warn( - "Low Level context is None. Return empty Low entity/relationship/source" - ) - ll_entities, ll_relationships, ll_sources = "", "", "" - else: - ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) - + hl_entities, ll_entities = entities[0], entities[1] + hl_relationships, ll_relationships = relationships[0], relationships[1] + hl_sources, ll_sources = sources[0], sources[1] # Combine and deduplicate the entities combined_entities = process_combine_contexts(hl_entities, ll_entities) @@ -1068,21 +974,7 @@ def combine_contexts(high_level_context, low_level_context): # Combine and deduplicate the sources combined_sources = process_combine_contexts(hl_sources, ll_sources) - # Format the combined context - return f""" ------Entities----- -```csv -{combined_entities} -``` ------Relationships----- -```csv -{combined_relationships} -``` ------Sources----- -```csv -{combined_sources} -``` -""" + return combined_entities, combined_relationships, combined_sources async def naive_query( @@ -1105,13 +997,15 @@ async def naive_query( max_token_size=query_param.max_token_for_text_unit, ) logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") - section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) if query_param.only_need_context: return section sys_prompt_temp = PROMPTS["naive_rag_response"] sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5de116b3..0d4e599d 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -2,6 +2,7 @@ GRAPH_FIELD_SEP = "" PROMPTS = {} +PROMPTS["DEFAULT_LANGUAGE"] = "English" PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" @@ -11,6 +12,7 @@ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"] PROMPTS["entity_extraction"] = """-Goal- Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. +Use {language} as output language. -Steps- 1. Identify all entities. 
For each identified entity, extract the following information: @@ -38,7 +40,19 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter} Union[str, None]: """Locate the JSON string body from a string""" - maybe_json_str = re.search(r"{.*}", content, re.DOTALL) - if maybe_json_str is not None: - return maybe_json_str.group(0) - else: + try: + maybe_json_str = re.search(r"{.*}", content, re.DOTALL) + if maybe_json_str is not None: + maybe_json_str = maybe_json_str.group(0) + maybe_json_str = maybe_json_str.replace("\\n", "") + maybe_json_str = maybe_json_str.replace("\n", "") + maybe_json_str = maybe_json_str.replace("'", '"') + json.loads(maybe_json_str) + return maybe_json_str + except Exception: + pass + # try: + # content = ( + # content.replace(kw_prompt[:-1], "") + # .replace("user", "") + # .replace("model", "") + # .strip() + # ) + # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" + # json.loads(maybe_json_str) + return None
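Finally, a short usage sketch of the query-side behavior added in this diff, assuming documents have already been inserted into the working directory and, for the HTTP calls, that the Oracle demo server from `examples/lightrag_api_oracle_demo..py` is running on `127.0.0.1:8020`; the `requests` dependency and the sample question are illustrative.

```python
import requests  # third-party HTTP client, used only for the server examples

from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

# Library use: the new `only_need_prompt` flag returns the fully assembled
# system prompt (keywords extracted, context built) without generating an answer.
rag = LightRAG(working_dir="./lightrag_cache_demo", llm_model_func=gpt_4o_mini_complete)
prompt_only = rag.query(
    "What are the main themes of this corpus?",
    param=QueryParam(mode="hybrid", only_need_prompt=True),
)
print(prompt_only)

# Server use: POST /query now accepts the `only_need_prompt` field...
resp = requests.post(
    "http://127.0.0.1:8020/query",
    json={"query": "What are the main themes of this corpus?",
          "mode": "hybrid", "only_need_prompt": True},
)
print(resp.json())

# ...and GET /data exposes the stored graph as nodes, edges, or statistics.
stats = requests.get("http://127.0.0.1:8020/data",
                     params={"type": "statistics", "limit": 100})
print(stats.json())
```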