diff --git a/.gitignore b/.gitignore
index 01e145a8..e6f5f5ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,4 @@ ignore_this.txt
 *.ignore.*
 .ruff_cache/
 gui/
-*.log
\ No newline at end of file
+*.log
diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py
index c06b8a83..8aaa2cf5 100644
--- a/examples/lightrag_api_oracle_demo..py
+++ b/examples/lightrag_api_oracle_demo..py
@@ -1,16 +1,14 @@
-
 from fastapi import FastAPI, HTTPException, File, UploadFile
 from fastapi import Query
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
-from typing import Optional,Any
-from fastapi.responses import JSONResponse
+from typing import Optional, Any
+
+import sys
+import os
+
-import sys, os
-print(os.getcwd())
 from pathlib import Path
-script_directory = Path(__file__).resolve().parent.parent
-sys.path.append(os.path.abspath(script_directory))
 
 import asyncio
 import nest_asyncio
@@ -18,10 +16,12 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from datetime import datetime
 from lightrag.kg.oracle_impl import OracleDB
 
+print(os.getcwd())
+script_directory = Path(__file__).resolve().parent.parent
+sys.path.append(os.path.abspath(script_directory))
+
 
 # Apply nest_asyncio to solve event loop issues
@@ -47,7 +47,8 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
-
+
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -77,10 +78,10 @@ async def get_embedding_dim():
     embedding_dim = embedding.shape[1]
     return embedding_dim
 
+
 async def init():
-    # Detect embedding dimension
-    embedding_dimension = 1024 #await get_embedding_dim()
+    embedding_dimension = 1024  # await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
     # Create Oracle DB connection
     # The `config` parameter is the connection configuration of Oracle DB
@@ -88,36 +89,36 @@ async def init():
     # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
 
+    oracle_db = OracleDB(
+        config={
+            "user": "",
+            "password": "",
+            "dsn": "",
+            "config_dir": "path_to_config_dir",
+            "wallet_location": "path_to_wallet_location",
+            "wallet_password": "wallet_password",
+            "workspace": "company",
+        }  # specify which docs you want to store and query
+    )
-    oracle_db = OracleDB(config={
-        "user":"",
-        "password":"",
-        "dsn":"",
-        "config_dir":"path_to_config_dir",
-        "wallet_location":"path_to_wallet_location",
-        "wallet_password":"wallet_password",
-        "workspace":"company"
-        } # specify which docs you want to store and query
-        )
-
     # Check if Oracle DB tables exist, if not, tables will be created
     await oracle_db.check_tables()
 
     # Initialize LightRAG
-    # We use Oracle DB as the KV/vector/graph storage
+    # We use Oracle DB as the KV/vector/graph storage
     rag = LightRAG(
-        enable_llm_cache=False,
-        working_dir=WORKING_DIR,
-        chunk_token_size=512,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=embedding_dimension,
-            max_token_size=512,
-            func=embedding_func,
-        ),
-        graph_storage = "OracleGraphStorage",
-        kv_storage="OracleKVStorage",
-        vector_storage="OracleVectorDBStorage"
-    )
+        enable_llm_cache=False,
+        working_dir=WORKING_DIR,
+        chunk_token_size=512,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=embedding_dimension,
+            max_token_size=512,
+            func=embedding_func,
+        ),
+        graph_storage="OracleGraphStorage",
+        kv_storage="OracleKVStorage",
+        vector_storage="OracleVectorDBStorage",
+    )
 
     # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
     rag.graph_storage_cls.db = oracle_db
@@ -128,7 +129,7 @@ async def init():
 
 
 # Extract and Insert into LightRAG storage
-#with open("./dickens/book.txt", "r", encoding="utf-8") as f:
+# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
 #    await rag.ainsert(f.read())
 #
 # Perform search in different modes
@@ -147,9 +148,11 @@ class QueryRequest(BaseModel):
     only_need_context: bool = False
     only_need_prompt: bool = False
 
+
 class DataRequest(BaseModel):
     limit: int = 100
 
+
 class InsertRequest(BaseModel):
     text: str
@@ -164,6 +167,7 @@ class Response(BaseModel):
 
 rag = None
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
@@ -172,25 +176,28 @@ async def lifespan(app: FastAPI):
     yield
 
 
-app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
+app = FastAPI(
+    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
+)
+
 
 @app.post("/query", response_model=Response)
 async def query_endpoint(request: QueryRequest):
-    #try:
-    #    loop = asyncio.get_event_loop()
+    # try:
+    #     loop = asyncio.get_event_loop()
     if request.mode == "naive":
         top_k = 3
     else:
         top_k = 60
     result = await rag.aquery(
-            request.query,
-            param=QueryParam(
-                mode=request.mode,
-                only_need_context=request.only_need_context,
-                only_need_prompt=request.only_need_prompt,
-                top_k=top_k
-            ),
-        )
+        request.query,
+        param=QueryParam(
+            mode=request.mode,
+            only_need_context=request.only_need_context,
+            only_need_prompt=request.only_need_prompt,
+            top_k=top_k,
+        ),
+    )
     return Response(status="success", data=result)
 # except Exception as e:
 #     raise HTTPException(status_code=500, detail=str(e))
@@ -199,9 +206,9 @@ async def query_endpoint(request: QueryRequest):
 
 @app.get("/data", response_model=Response)
 async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
     if type == "nodes":
-        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
+        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
     elif type == "edges":
-        result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
+        result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
     elif type == "statistics":
         result = await rag.chunk_entity_relation_graph.get_statistics()
     return Response(status="success", data=result)
@@ -264,4 +271,4 @@ if __name__ == "__main__":
 # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
 # 4. Health check:
-# curl -X GET "http://127.0.0.1:8020/health"
\ No newline at end of file
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index b915c76b..630c1fd8 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -97,8 +97,7 @@ async def main():
         graph_storage="OracleGraphStorage",
         kv_storage="OracleKVStorage",
         vector_storage="OracleVectorDBStorage",
-
-        addon_params = {"example_number":1, "language":"Simplfied Chinese"},
+        addon_params={"example_number": 1, "language": "Simplified Chinese"},
     )
 
     # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 2cfbd249..08ce79d5 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,7 +114,9 @@ class OracleDB:
 
         logger.info("Finished check all tables in Oracle database")
 
-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(
+        self, sql: str, params: dict = None, multirows: bool = False
+    ) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
@@ -256,7 +258,7 @@ class OracleKVStorage(BaseKVStorage):
                     item["__vector__"],
                 ]
                 # print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)
 
         if self.namespace == "full_docs":
             for k, v in self._data.items():
@@ -266,7 +268,7 @@ class OracleKVStorage(BaseKVStorage):
                 )
                 values = [k, self._data[k]["content"], self.db.workspace]
                 # print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)
         return left_data
 
     async def index_done_callback(self):
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 1acf07e0..d3729941 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -70,8 +70,8 @@ async def openai_complete_if_cache(
         model=model, messages=messages, **kwargs
     )
     content = response.choices[0].message.content
-    if r'\u' in content:
-        content = content.encode('utf-8').decode('unicode_escape')
+    if r"\u" in content:
+        content = content.encode("utf-8").decode("unicode_escape")
     print(content)
     if hashing_kv is not None:
         await hashing_kv.upsert(
@@ -542,7 +542,7 @@ async def openai_embedding(
     texts: list[str],
     model: str = "text-embedding-3-small",
     base_url: str = None,
-    api_key: str = None
+    api_key: str = None,
 ) -> np.ndarray:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
@@ -551,7 +551,7 @@ async def openai_embedding(
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
     response = await openai_async_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
+        model=model, input=texts, encoding_format="float"
    )
     return np.array([dp.embedding for dp in response.data])
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 1071f8c2..c4740e70 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -249,13 +249,17 @@ async def extract_entities(
     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
-    language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"])
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )
     example_number = global_config["addon_params"].get("example_number", None)
-    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
 
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
"") @@ -496,44 +504,72 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, - ): +): ll_kewwords, hl_keywrds = query[0], query[1] if query_param.mode in ["local", "hybrid"]: if ll_kewwords == "": - ll_entities_context,ll_relations_context,ll_text_units_context = "","","" - warnings.warn("Low Level context is None. Return empty Low entity/relationship/source") + ll_entities_context, ll_relations_context, ll_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "Low Level context is None. Return empty Low entity/relationship/source" + ) query_param.mode = "global" else: - ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data( + ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) = await _get_node_data( ll_kewwords, knowledge_graph_inst, entities_vdb, text_chunks_db, - query_param - ) + query_param, + ) if query_param.mode in ["global", "hybrid"]: if hl_keywrds == "": - hl_entities_context,hl_relations_context,hl_text_units_context = "","","" - warnings.warn("High Level context is None. Return empty High entity/relationship/source") + hl_entities_context, hl_relations_context, hl_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "High Level context is None. Return empty High entity/relationship/source" + ) query_param.mode = "local" else: - hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data( + ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) = await _get_edge_data( hl_keywrds, knowledge_graph_inst, relationships_vdb, text_chunks_db, - query_param - ) - if query_param.mode == 'hybrid': - entities_context,relations_context,text_units_context = combine_contexts( - [hl_entities_context,ll_entities_context], - [hl_relations_context,ll_relations_context], - [hl_text_units_context,ll_text_units_context] - ) - elif query_param.mode == 'local': - entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context - elif query_param.mode == 'global': - entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context + query_param, + ) + if query_param.mode == "hybrid": + entities_context, relations_context, text_units_context = combine_contexts( + [hl_entities_context, ll_entities_context], + [hl_relations_context, ll_relations_context], + [hl_text_units_context, ll_text_units_context], + ) + elif query_param.mode == "local": + entities_context, relations_context, text_units_context = ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) + elif query_param.mode == "global": + entities_context, relations_context, text_units_context = ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) return f""" # -----Entities----- # ```csv @@ -550,7 +586,6 @@ async def _build_query_context( # """ - async def _get_node_data( query, knowledge_graph_inst: BaseGraphStorage, @@ -568,7 +603,7 @@ async def _get_node_data( ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") - + # 获取实体的度 node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] @@ -588,7 +623,7 @@ async def _get_node_data( ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" - ) + ) 
 
     # Build the prompt
     entites_section_list = [["id", "entity", "type", "description", "rank"]]
@@ -625,7 +660,7 @@ async def _get_node_data(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-    return entities_context,relations_context,text_units_context
+    return entities_context, relations_context, text_units_context
 
 
 async def _find_most_related_text_unit_from_entities(
@@ -821,8 +856,7 @@ async def _get_edge_data(
     for i, t in enumerate(use_text_units):
         text_units_section_list.append([i, t["content"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
-    return entities_context,relations_context,text_units_context
-
+    return entities_context, relations_context, text_units_context
 
 
 async def _find_most_related_entities_from_relationships(
@@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships(
 
 def combine_contexts(entities, relationships, sources):
     # Function to extract entities, relationships, and sources from context strings
     hl_entities, ll_entities = entities[0], entities[1]
-    hl_relationships, ll_relationships = relationships[0],relationships[1]
+    hl_relationships, ll_relationships = relationships[0], relationships[1]
     hl_sources, ll_sources = sources[0], sources[1]
     # Combine and deduplicate the entities
     combined_entities = process_combine_contexts(hl_entities, ll_entities)
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 389f45f2..0d4e599d 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -52,7 +52,7 @@ Output:
 """
 
 PROMPTS["entity_extraction_examples"] = [
-"""Example 1:
+    """Example 1:
 
 Entity_types: [person, technology, mission, organization, location]
 Text:
@@ -77,7 +77,7 @@ Output:
 ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
 ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
 #############################""",
-"""Example 2:
+    """Example 2:
 
 Entity_types: [person, technology, mission, organization, location]
 Text:
@@ -95,7 +95,7 @@ Output:
 ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
 ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
 #############################""",
-"""Example 3:
+    """Example 3:
 
 Entity_types: [person, role, technology, organization, event, location, concept]
 Text:
@@ -121,10 +121,12 @@ Output:
 ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
 ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
 ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
significance"){completion_delimiter} -#############################""" +#############################""", ] -PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +PROMPTS[ + "summarize_entity_descriptions" +] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. @@ -139,10 +141,14 @@ Description List: {description_list} Output: """ -PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format: +PROMPTS[ + "entiti_continue_extraction" +] = """MANY entities were missed in the last extraction. Add them below using the same format: """ -PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. +PROMPTS[ + "entiti_if_loop_extraction" +] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. """ PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." @@ -201,7 +207,7 @@ Output: """ PROMPTS["keywords_extraction_examples"] = [ - """Example 1: + """Example 1: Query: "How does international trade influence global economic stability?" ################ @@ -211,7 +217,7 @@ Output: "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] }} #############################""", - """Example 2: + """Example 2: Query: "What are the environmental consequences of deforestation on biodiversity?" ################ @@ -220,8 +226,8 @@ Output: "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] }} -#############################""", - """Example 3: +#############################""", + """Example 3: Query: "What is the role of education in reducing poverty?" ################ @@ -230,8 +236,8 @@ Output: "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] }} -#############################""" -] +#############################""", +] PROMPTS["naive_rag_response"] = """---Role--- diff --git a/lightrag/utils.py b/lightrag/utils.py index fc739002..bdd1aa9e 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -56,7 +56,8 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: maybe_json_str = maybe_json_str.replace("'", '"') json.loads(maybe_json_str) return maybe_json_str - except: + except Exception: + pass # try: # content = ( # content.replace(kw_prompt[:-1], "") @@ -64,9 +65,9 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: # .replace("model", "") # .strip() # ) - # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" + # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" # json.loads(maybe_json_str) - + return None