From 79646fced8612187ec26e013eb34d19c210e4908 Mon Sep 17 00:00:00 2001 From: xYLiuuuuuu Date: Mon, 6 Jan 2025 16:54:53 +0800 Subject: [PATCH 01/38] Fix:Optimized logic for automatic switching modes when keywords do not exist --- lightrag/operate.py | 117 ++++++++++++++++---------------------------- 1 file changed, 42 insertions(+), 75 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index f21e41ff..c8e4565c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -522,15 +522,16 @@ async def kg_query( logger.warning("low_level_keywords and high_level_keywords is empty") return PROMPTS["fail_response"] if ll_keywords == [] and query_param.mode in ["local", "hybrid"]: - logger.warning("low_level_keywords is empty") - return PROMPTS["fail_response"] - else: - ll_keywords = ", ".join(ll_keywords) + logger.warning("low_level_keywords is empty, switching from %s mode to global mode", query_param.mode) + query_param.mode = "global" if hl_keywords == [] and query_param.mode in ["global", "hybrid"]: - logger.warning("high_level_keywords is empty") - return PROMPTS["fail_response"] - else: - hl_keywords = ", ".join(hl_keywords) + logger.warning("high_level_keywords is empty, switching from %s mode to local mode", query_param.mode) + query_param.mode = "local" + + ll_keywords = ", ".join(ll_keywords) if ll_keywords else "" + hl_keywords = ", ".join(hl_keywords) if hl_keywords else "" + + logger.info("Using %s mode for query processing", query_param.mode) # Build context keywords = [ll_keywords, hl_keywords] @@ -596,78 +597,44 @@ async def _build_query_context( # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", "" # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", "" - ll_kewwords, hl_keywrds = query[0], query[1] - if query_param.mode in ["local", "hybrid"]: - if ll_kewwords == "": - ll_entities_context, ll_relations_context, ll_text_units_context = ( - "", - "", - "", - ) - warnings.warn( - "Low Level context is None. Return empty Low entity/relationship/source" - ) - query_param.mode = "global" - else: - ( - ll_entities_context, - ll_relations_context, - ll_text_units_context, - ) = await _get_node_data( - ll_kewwords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) - if query_param.mode in ["global", "hybrid"]: - if hl_keywrds == "": - hl_entities_context, hl_relations_context, hl_text_units_context = ( - "", - "", - "", - ) - warnings.warn( - "High Level context is None. Return empty High entity/relationship/source" - ) - query_param.mode = "local" - else: - ( - hl_entities_context, - hl_relations_context, - hl_text_units_context, - ) = await _get_edge_data( - hl_keywrds, - knowledge_graph_inst, - relationships_vdb, - text_chunks_db, - query_param, - ) - if ( - hl_entities_context == "" - and hl_relations_context == "" - and hl_text_units_context == "" - ): - logger.warn("No high level context found. 
Switching to local mode.") - query_param.mode = "local" - if query_param.mode == "hybrid": + ll_keywords, hl_keywords = query[0], query[1] + + if query_param.mode == "local": + entities_context, relations_context, text_units_context = await _get_node_data( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ) + elif query_param.mode == "global": + entities_context, relations_context, text_units_context = await _get_edge_data( + hl_keywords, + knowledge_graph_inst, + relationships_vdb, + text_chunks_db, + query_param, + ) + else: # hybrid mode + ll_entities_context, ll_relations_context, ll_text_units_context = await _get_node_data( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ) + hl_entities_context, hl_relations_context, hl_text_units_context = await _get_edge_data( + hl_keywords, + knowledge_graph_inst, + relationships_vdb, + text_chunks_db, + query_param, + ) entities_context, relations_context, text_units_context = combine_contexts( [hl_entities_context, ll_entities_context], [hl_relations_context, ll_relations_context], [hl_text_units_context, ll_text_units_context], ) - elif query_param.mode == "local": - entities_context, relations_context, text_units_context = ( - ll_entities_context, - ll_relations_context, - ll_text_units_context, - ) - elif query_param.mode == "global": - entities_context, relations_context, text_units_context = ( - hl_entities_context, - hl_relations_context, - hl_text_units_context, - ) return f""" -----Entities----- ```csv From 536d6f2283815fedb2c423010504fb12fc440055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A5=E7=9F=B3=E6=B8=8A?= Date: Tue, 7 Jan 2025 00:28:15 +0800 Subject: [PATCH 02/38] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E5=88=86=E5=89=B2=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=9C=A8=E2=80=9C?= =?UTF-8?q?insert=E2=80=9D=E5=87=BD=E6=95=B0=E4=B8=AD=E5=A6=82=E6=9E=9C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=8F=82=E6=95=B0split=5Fby=5Fcharacter?= =?UTF-8?q?=EF=BC=8C=E5=88=99=E4=BC=9A=E6=8C=89=E7=85=A7split=5Fby=5Fchara?= =?UTF-8?q?cter=E8=BF=9B=E8=A1=8C=E5=AD=97=E7=AC=A6=E5=88=86=E5=89=B2?= =?UTF-8?q?=EF=BC=8C=E6=AD=A4=E6=97=B6=E5=A6=82=E6=9E=9C=E6=AF=8F=E4=B8=AA?= =?UTF-8?q?=E5=88=86=E5=89=B2=E5=90=8E=E7=9A=84chunk=E7=9A=84tokens?= =?UTF-8?q?=E5=A4=A7=E4=BA=8Emax=5Ftoken=5Fsize=EF=BC=8C=E5=88=99=E4=BC=9A?= =?UTF-8?q?=E7=BB=A7=E7=BB=AD=E6=8C=89token=5Fsize=E5=88=86=E5=89=B2?= =?UTF-8?q?=EF=BC=88todo=EF=BC=9A=E8=80=83=E8=99=91=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E5=88=86=E5=89=B2=E5=90=8E=E8=BF=87=E7=9F=AD=E7=9A=84chunk?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/lightrag.py | 41 ++++--- lightrag/operate.py | 276 +++++++++++++++++++++++-------------------- 2 files changed, 171 insertions(+), 146 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cbe49da2..47d64ac0 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -45,6 +45,7 @@ from .storage import ( from .prompt import GRAPH_FIELD_SEP + # future KG integrations # from .kg.ArangoDB_impl import ( @@ -167,7 +168,7 @@ class LightRAG: # LLM llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# - llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' + llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_max_token_size: int = 32768 llm_model_max_async: int = 
16 llm_model_kwargs: dict = field(default_factory=dict) @@ -267,7 +268,7 @@ class LightRAG: self.llm_model_func, hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -313,15 +314,16 @@ class LightRAG: "JsonDocStatusStorage": JsonDocStatusStorage, } - def insert(self, string_or_strings): + def insert(self, string_or_strings, split_by_character=None): loop = always_get_an_event_loop() - return loop.run_until_complete(self.ainsert(string_or_strings)) + return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character)) - async def ainsert(self, string_or_strings): + async def ainsert(self, string_or_strings, split_by_character): """Insert documents with checkpoint support Args: string_or_strings: Single document string or list of document strings + split_by_character: if split_by_character is not None, split the string by character """ if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] @@ -355,10 +357,10 @@ class LightRAG: # Process documents in batches batch_size = self.addon_params.get("insert_batch_size", 10) for i in range(0, len(new_docs), batch_size): - batch_docs = dict(list(new_docs.items())[i : i + batch_size]) + batch_docs = dict(list(new_docs.items())[i: i + batch_size]) for doc_id, doc in tqdm_async( - batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}" + batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}" ): try: # Update status to processing @@ -379,6 +381,7 @@ class LightRAG: } for dp in chunking_by_token_size( doc["content"], + split_by_character=split_by_character, overlap_token_size=self.chunk_overlap_token_size, max_token_size=self.chunk_token_size, tiktoken_model=self.tiktoken_model_name, @@ -545,7 +548,7 @@ class LightRAG: # Check if nodes exist in the knowledge graph for need_insert_id in [src_id, tgt_id]: if not ( - await self.chunk_entity_relation_graph.has_node(need_insert_id) + await self.chunk_entity_relation_graph.has_node(need_insert_id) ): await self.chunk_entity_relation_graph.upsert_node( need_insert_id, @@ -594,9 +597,9 @@ class LightRAG: "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], } for dp in all_relationships_data } @@ -621,7 +624,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -637,7 +640,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -656,7 +659,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -897,7 +900,7 @@ class LightRAG: dp 
for dp in self.entities_vdb.client_storage["data"] if chunk_id - in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) ] if entities_with_chunk: logger.error( @@ -909,7 +912,7 @@ class LightRAG: dp for dp in self.relationships_vdb.client_storage["data"] if chunk_id - in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) ] if relations_with_chunk: logger.error( @@ -926,7 +929,7 @@ class LightRAG: return asyncio.run(self.adelete_by_doc_id(doc_id)) async def get_entity_info( - self, entity_name: str, include_vector_data: bool = False + self, entity_name: str, include_vector_data: bool = False ): """Get detailed information of an entity @@ -977,7 +980,7 @@ class LightRAG: tracemalloc.stop() async def get_relation_info( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False + self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): """Get detailed information of a relationship @@ -1019,7 +1022,7 @@ class LightRAG: return result def get_relation_info_sync( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False + self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): """Synchronous version of getting relationship information diff --git a/lightrag/operate.py b/lightrag/operate.py index b2c4d215..e8f0df65 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -34,30 +34,52 @@ import time def chunking_by_token_size( - content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" + content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" ): tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) results = [] - for index, start in enumerate( - range(0, len(tokens), max_token_size - overlap_token_size) - ): - chunk_content = decode_tokens_by_tiktoken( - tokens[start : start + max_token_size], model_name=tiktoken_model - ) - results.append( - { - "tokens": min(max_token_size, len(tokens) - start), - "content": chunk_content.strip(), - "chunk_order_index": index, - } - ) + if split_by_character: + raw_chunks = content.split(split_by_character) + new_chunks = [] + for chunk in raw_chunks: + _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model) + if len(_tokens) > max_token_size: + for start in range(0, len(_tokens), max_token_size - overlap_token_size): + chunk_content = decode_tokens_by_tiktoken( + _tokens[start: start + max_token_size], model_name=tiktoken_model + ) + new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content)) + else: + new_chunks.append((len(_tokens), chunk)) + for index, (_len, chunk) in enumerate(new_chunks): + results.append( + { + "tokens": _len, + "content": chunk.strip(), + "chunk_order_index": index, + } + ) + else: + for index, start in enumerate( + range(0, len(tokens), max_token_size - overlap_token_size) + ): + chunk_content = decode_tokens_by_tiktoken( + tokens[start: start + max_token_size], model_name=tiktoken_model + ) + results.append( + { + "tokens": min(max_token_size, len(tokens) - start), + "content": chunk_content.strip(), + "chunk_order_index": index, + } + ) return results async def _handle_entity_relation_summary( - entity_or_relation_name: str, - description: str, - global_config: dict, + entity_or_relation_name: str, + description: str, + global_config: dict, ) -> str: use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = 
global_config["llm_model_max_token_size"] @@ -86,8 +108,8 @@ async def _handle_entity_relation_summary( async def _handle_single_entity_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], + chunk_key: str, ): if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None @@ -107,8 +129,8 @@ async def _handle_single_entity_extraction( async def _handle_single_relationship_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], + chunk_key: str, ): if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None @@ -134,10 +156,10 @@ async def _handle_single_relationship_extraction( async def _merge_nodes_then_upsert( - entity_name: str, - nodes_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + entity_name: str, + nodes_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + global_config: dict, ): already_entity_types = [] already_source_ids = [] @@ -181,11 +203,11 @@ async def _merge_nodes_then_upsert( async def _merge_edges_then_upsert( - src_id: str, - tgt_id: str, - edges_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + src_id: str, + tgt_id: str, + edges_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + global_config: dict, ): already_weights = [] already_source_ids = [] @@ -248,12 +270,12 @@ async def _merge_edges_then_upsert( async def extract_entities( - chunks: dict[str, TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, - entity_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - global_config: dict, - llm_response_cache: BaseKVStorage = None, + chunks: dict[str, TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + global_config: dict, + llm_response_cache: BaseKVStorage = None, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -305,13 +327,13 @@ async def extract_entities( already_relations = 0 async def _user_llm_func_with_cache( - input_text: str, history_messages: list[dict[str, str]] = None + input_text: str, history_messages: list[dict[str, str]] = None ) -> str: if enable_llm_cache_for_entity_extract and llm_response_cache: need_to_restore = False if ( - global_config["embedding_cache_config"] - and global_config["embedding_cache_config"]["enabled"] + global_config["embedding_cache_config"] + and global_config["embedding_cache_config"]["enabled"] ): new_config = global_config.copy() new_config["embedding_cache_config"] = None @@ -413,7 +435,7 @@ async def extract_entities( already_relations += len(maybe_edges) now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) - ] + ] print( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", end="", @@ -423,10 +445,10 @@ async def extract_entities( results = [] for result in tqdm_async( - asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), - total=len(ordered_chunks), - desc="Extracting entities from chunks", - unit="chunk", + asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), + total=len(ordered_chunks), + desc="Extracting entities from chunks", + unit="chunk", ): results.append(await result) @@ -440,32 +462,32 @@ async def 
extract_entities( logger.info("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( - asyncio.as_completed( - [ - _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) - for k, v in maybe_nodes.items() - ] - ), - total=len(maybe_nodes), - desc="Inserting entities", - unit="entity", + asyncio.as_completed( + [ + _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) + for k, v in maybe_nodes.items() + ] + ), + total=len(maybe_nodes), + desc="Inserting entities", + unit="entity", ): all_entities_data.append(await result) logger.info("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( - asyncio.as_completed( - [ - _merge_edges_then_upsert( - k[0], k[1], v, knowledge_graph_inst, global_config - ) - for k, v in maybe_edges.items() - ] - ), - total=len(maybe_edges), - desc="Inserting relationships", - unit="relationship", + asyncio.as_completed( + [ + _merge_edges_then_upsert( + k[0], k[1], v, knowledge_graph_inst, global_config + ) + for k, v in maybe_edges.items() + ] + ), + total=len(maybe_edges), + desc="Inserting relationships", + unit="relationship", ): all_relationships_data.append(await result) @@ -496,9 +518,9 @@ async def extract_entities( "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], "metadata": { "created_at": dp.get("metadata", {}).get("created_at", time.time()) }, @@ -511,14 +533,14 @@ async def extract_entities( async def kg_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: # Handle cache use_model_func = global_config["llm_model_func"] @@ -638,12 +660,12 @@ async def kg_query( async def _build_query_context( - query: list, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + query: list, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", "" # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", "" @@ -696,9 +718,9 @@ async def _build_query_context( query_param, ) if ( - hl_entities_context == "" - and hl_relations_context == "" - and hl_text_units_context == "" + hl_entities_context == "" + and hl_relations_context == "" + and hl_text_units_context == "" ): logger.warn("No high level context found. 
Switching to local mode.") query_param.mode = "local" @@ -737,11 +759,11 @@ async def _build_query_context( async def _get_node_data( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): # get similar entities results = await entities_vdb.query(query, top_k=query_param.top_k) @@ -828,10 +850,10 @@ async def _get_node_data( async def _find_most_related_text_unit_from_entities( - node_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage[TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) @@ -871,8 +893,8 @@ async def _find_most_related_text_unit_from_entities( if this_edges: for e in this_edges: if ( - e[1] in all_one_hop_text_units_lookup - and c_id in all_one_hop_text_units_lookup[e[1]] + e[1] in all_one_hop_text_units_lookup + and c_id in all_one_hop_text_units_lookup[e[1]] ): all_text_units_lookup[c_id]["relation_counts"] += 1 @@ -902,9 +924,9 @@ async def _find_most_related_text_unit_from_entities( async def _find_most_related_edges_from_entities( - node_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, ): all_related_edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] @@ -942,11 +964,11 @@ async def _find_most_related_edges_from_entities( async def _get_edge_data( - keywords, - knowledge_graph_inst: BaseGraphStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + keywords, + knowledge_graph_inst: BaseGraphStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): results = await relationships_vdb.query(keywords, top_k=query_param.top_k) @@ -1044,9 +1066,9 @@ async def _get_edge_data( async def _find_most_related_entities_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, ): entity_names = [] seen = set() @@ -1081,10 +1103,10 @@ async def _find_most_related_entities_from_relationships( async def _find_related_text_unit_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage[TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) @@ -1150,12 +1172,12 @@ def combine_contexts(entities, relationships, sources): async def naive_query( - query, - chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + chunks_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + 
query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ): # Handle cache use_model_func = global_config["llm_model_func"] @@ -1213,7 +1235,7 @@ async def naive_query( if len(response) > len(sys_prompt): response = ( - response[len(sys_prompt) :] + response[len(sys_prompt):] .replace(sys_prompt, "") .replace("user", "") .replace("model", "") @@ -1241,15 +1263,15 @@ async def naive_query( async def mix_kg_vector_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + chunks_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: """ Hybrid retrieval implementation combining knowledge graph and vector search. @@ -1274,7 +1296,7 @@ async def mix_kg_vector_query( # Reuse keyword extraction logic from kg_query example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len( - PROMPTS["keywords_extraction_examples"] + PROMPTS["keywords_extraction_examples"] ): examples = "\n".join( PROMPTS["keywords_extraction_examples"][: int(example_number)] From 196350b75bef6c07eef0ee19cc55c45dcd1784d5 Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Tue, 7 Jan 2025 07:02:37 +0800 Subject: [PATCH 03/38] Revise the readme to fix the broken link. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ed2a7789..ea8d0a97 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News -- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage). +- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise. - [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author. 
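
The split_by_character option added in PATCH 02/38 changes how documents are chunked before entity extraction: the text is first split on the given separator, and any piece whose token count still exceeds max_token_size falls back to overlapping token windows. The standalone sketch below mirrors that intended behavior; it is an illustration rather than the library's chunking_by_token_size(), and the direct use of tiktoken here is an assumption made for the example.

```python
# Sketch of the split_by_character chunking behavior from PATCH 02/38 (illustrative, not library code).
import tiktoken


def chunk_text(content, split_by_character=None, overlap_token_size=128,
               max_token_size=1024, tiktoken_model="gpt-4o"):
    enc = tiktoken.encoding_for_model(tiktoken_model)
    pieces = []  # (token_count, text) pairs, in document order
    raw = content.split(split_by_character) if split_by_character else [content]
    for piece in raw:
        tokens = enc.encode(piece)
        if split_by_character and len(tokens) <= max_token_size:
            # A separator-delimited piece that already fits is kept whole.
            pieces.append((len(tokens), piece))
            continue
        # Oversized pieces (or the whole text when no separator is given)
        # are re-split into overlapping token windows.
        step = max_token_size - overlap_token_size
        for start in range(0, len(tokens), step):
            window = tokens[start:start + max_token_size]
            pieces.append((len(window), enc.decode(window)))
    return [
        {"tokens": n, "content": text.strip(), "chunk_order_index": i}
        for i, (n, text) in enumerate(pieces)
    ]


# Example: rag.insert(long_text, split_by_character="\n\n") chunks on blank lines first
# and only re-splits paragraphs that are longer than max_token_size tokens.
```
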
From 3bbd3ee1b232cf1335617a5f4308651b295061b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A5=E7=9F=B3=E6=B8=8A?= Date: Tue, 7 Jan 2025 13:45:18 +0800 Subject: [PATCH 04/38] =?UTF-8?q?=E5=9C=A8Mac=E7=AB=AFtorch~=3D2.5.1+cu121?= =?UTF-8?q?=E4=BC=9A=E5=AF=BC=E8=87=B4=E6=9C=AC=E5=9C=B0=E5=AE=89=E8=A3=85?= =?UTF-8?q?=E6=97=B6=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 79249e7e..dd3c4cf3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,8 @@ tenacity~=9.0.0 # LLM packages tiktoken~=0.8.0 -torch~=2.5.1+cu121 +# torch~=2.5.1+cu121 +torch~=2.5.1 tqdm~=4.67.1 transformers~=4.47.1 xxhash From 290744d77040799c2c238524ad39cb1355c1182f Mon Sep 17 00:00:00 2001 From: LarFii <834462287@qq.com> Date: Tue, 7 Jan 2025 16:04:46 +0800 Subject: [PATCH 05/38] fix requirements.txt --- requirements.txt | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/requirements.txt b/requirements.txt index 79249e7e..e81473ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,38 +1,38 @@ accelerate -aioboto3~=13.3.0 -aiofiles~=24.1.0 -aiohttp~=3.11.11 -asyncpg~=0.30.0 +aioboto3 +aiofiles +aiohttp +asyncpg # database packages graspologic gremlinpython hnswlib nano-vectordb -neo4j~=5.27.0 -networkx~=3.2.1 +neo4j +networkx -numpy~=2.2.0 -ollama~=0.4.4 -openai~=1.58.1 +numpy +ollama +openai oracledb -psycopg-pool~=3.2.4 -psycopg[binary,pool]~=3.2.3 -pydantic~=2.10.4 +psycopg-pool +psycopg[binary,pool] +pydantic pymilvus pymongo pymysql -python-dotenv~=1.0.1 -pyvis~=0.3.2 -setuptools~=70.0.0 +python-dotenv +pyvis +setuptools # lmdeploy[all] -sqlalchemy~=2.0.36 -tenacity~=9.0.0 +sqlalchemy +tenacity # LLM packages -tiktoken~=0.8.0 -torch~=2.5.1+cu121 -tqdm~=4.67.1 -transformers~=4.47.1 -xxhash +tiktoken +torch +tqdm +transformers +xxhash \ No newline at end of file From 9ef4fe667aeb0ac4b303de698fcdef3ae4fb1c20 Mon Sep 17 00:00:00 2001 From: LarFii <834462287@qq.com> Date: Tue, 7 Jan 2025 16:18:19 +0800 Subject: [PATCH 06/38] rename --- contributor-readme.MD => contributor-README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename contributor-readme.MD => contributor-README.md (100%) diff --git a/contributor-readme.MD b/contributor-README.md similarity index 100% rename from contributor-readme.MD rename to contributor-README.md From 79d705071027e15a57c54cc64bc07d2dda246498 Mon Sep 17 00:00:00 2001 From: LarFii <834462287@qq.com> Date: Tue, 7 Jan 2025 16:21:54 +0800 Subject: [PATCH 07/38] fix linting errors --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e81473ea..48c25ff8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,4 @@ tiktoken torch tqdm transformers -xxhash \ No newline at end of file +xxhash From 6b19401dc6f0a27597f15990bd86206409feb540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A5=E7=9F=B3=E6=B8=8A?= Date: Tue, 7 Jan 2025 16:26:12 +0800 Subject: [PATCH 08/38] chunk split retry --- lightrag/lightrag.py | 34 +- lightrag/operate.py | 247 ++++++++------- test.ipynb | 740 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 886 insertions(+), 135 deletions(-) create mode 100644 test.ipynb diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 47d64ac0..7496d736 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py 
@@ -268,7 +268,7 @@ class LightRAG: self.llm_model_func, hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -316,7 +316,9 @@ class LightRAG: def insert(self, string_or_strings, split_by_character=None): loop = always_get_an_event_loop() - return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character)) + return loop.run_until_complete( + self.ainsert(string_or_strings, split_by_character) + ) async def ainsert(self, string_or_strings, split_by_character): """Insert documents with checkpoint support @@ -357,10 +359,10 @@ class LightRAG: # Process documents in batches batch_size = self.addon_params.get("insert_batch_size", 10) for i in range(0, len(new_docs), batch_size): - batch_docs = dict(list(new_docs.items())[i: i + batch_size]) + batch_docs = dict(list(new_docs.items())[i : i + batch_size]) for doc_id, doc in tqdm_async( - batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}" + batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}" ): try: # Update status to processing @@ -548,7 +550,7 @@ class LightRAG: # Check if nodes exist in the knowledge graph for need_insert_id in [src_id, tgt_id]: if not ( - await self.chunk_entity_relation_graph.has_node(need_insert_id) + await self.chunk_entity_relation_graph.has_node(need_insert_id) ): await self.chunk_entity_relation_graph.upsert_node( need_insert_id, @@ -597,9 +599,9 @@ class LightRAG: "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], } for dp in all_relationships_data } @@ -624,7 +626,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -640,7 +642,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -659,7 +661,7 @@ class LightRAG: asdict(self), hashing_kv=self.llm_response_cache if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -900,7 +902,7 @@ class LightRAG: dp for dp in self.entities_vdb.client_storage["data"] if chunk_id - in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) ] if entities_with_chunk: logger.error( @@ -912,7 +914,7 @@ class LightRAG: dp for dp in self.relationships_vdb.client_storage["data"] if chunk_id - in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) + in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP) ] if relations_with_chunk: logger.error( @@ -929,7 +931,7 @@ class LightRAG: return asyncio.run(self.adelete_by_doc_id(doc_id)) async def get_entity_info( - self, entity_name: str, include_vector_data: bool = False + self, entity_name: 
str, include_vector_data: bool = False ): """Get detailed information of an entity @@ -980,7 +982,7 @@ class LightRAG: tracemalloc.stop() async def get_relation_info( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False + self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): """Get detailed information of a relationship @@ -1022,7 +1024,7 @@ class LightRAG: return result def get_relation_info_sync( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False + self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): """Synchronous version of getting relationship information diff --git a/lightrag/operate.py b/lightrag/operate.py index e8f0df65..1128b41c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -34,7 +34,11 @@ import time def chunking_by_token_size( - content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" + content: str, + split_by_character=None, + overlap_token_size=128, + max_token_size=1024, + tiktoken_model="gpt-4o", ): tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) results = [] @@ -44,11 +48,16 @@ def chunking_by_token_size( for chunk in raw_chunks: _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model) if len(_tokens) > max_token_size: - for start in range(0, len(_tokens), max_token_size - overlap_token_size): + for start in range( + 0, len(_tokens), max_token_size - overlap_token_size + ): chunk_content = decode_tokens_by_tiktoken( - _tokens[start: start + max_token_size], model_name=tiktoken_model + _tokens[start : start + max_token_size], + model_name=tiktoken_model, + ) + new_chunks.append( + (min(max_token_size, len(_tokens) - start), chunk_content) ) - new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content)) else: new_chunks.append((len(_tokens), chunk)) for index, (_len, chunk) in enumerate(new_chunks): @@ -61,10 +70,10 @@ def chunking_by_token_size( ) else: for index, start in enumerate( - range(0, len(tokens), max_token_size - overlap_token_size) + range(0, len(tokens), max_token_size - overlap_token_size) ): chunk_content = decode_tokens_by_tiktoken( - tokens[start: start + max_token_size], model_name=tiktoken_model + tokens[start : start + max_token_size], model_name=tiktoken_model ) results.append( { @@ -77,9 +86,9 @@ def chunking_by_token_size( async def _handle_entity_relation_summary( - entity_or_relation_name: str, - description: str, - global_config: dict, + entity_or_relation_name: str, + description: str, + global_config: dict, ) -> str: use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] @@ -108,8 +117,8 @@ async def _handle_entity_relation_summary( async def _handle_single_entity_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], + chunk_key: str, ): if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None @@ -129,8 +138,8 @@ async def _handle_single_entity_extraction( async def _handle_single_relationship_extraction( - record_attributes: list[str], - chunk_key: str, + record_attributes: list[str], + chunk_key: str, ): if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None @@ -156,10 +165,10 @@ async def _handle_single_relationship_extraction( async def _merge_nodes_then_upsert( - entity_name: str, - nodes_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + 
entity_name: str, + nodes_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + global_config: dict, ): already_entity_types = [] already_source_ids = [] @@ -203,11 +212,11 @@ async def _merge_nodes_then_upsert( async def _merge_edges_then_upsert( - src_id: str, - tgt_id: str, - edges_data: list[dict], - knowledge_graph_inst: BaseGraphStorage, - global_config: dict, + src_id: str, + tgt_id: str, + edges_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + global_config: dict, ): already_weights = [] already_source_ids = [] @@ -270,12 +279,12 @@ async def _merge_edges_then_upsert( async def extract_entities( - chunks: dict[str, TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, - entity_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - global_config: dict, - llm_response_cache: BaseKVStorage = None, + chunks: dict[str, TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + global_config: dict, + llm_response_cache: BaseKVStorage = None, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -327,13 +336,13 @@ async def extract_entities( already_relations = 0 async def _user_llm_func_with_cache( - input_text: str, history_messages: list[dict[str, str]] = None + input_text: str, history_messages: list[dict[str, str]] = None ) -> str: if enable_llm_cache_for_entity_extract and llm_response_cache: need_to_restore = False if ( - global_config["embedding_cache_config"] - and global_config["embedding_cache_config"]["enabled"] + global_config["embedding_cache_config"] + and global_config["embedding_cache_config"]["enabled"] ): new_config = global_config.copy() new_config["embedding_cache_config"] = None @@ -435,7 +444,7 @@ async def extract_entities( already_relations += len(maybe_edges) now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) - ] + ] print( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", end="", @@ -445,10 +454,10 @@ async def extract_entities( results = [] for result in tqdm_async( - asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), - total=len(ordered_chunks), - desc="Extracting entities from chunks", - unit="chunk", + asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), + total=len(ordered_chunks), + desc="Extracting entities from chunks", + unit="chunk", ): results.append(await result) @@ -462,32 +471,32 @@ async def extract_entities( logger.info("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( - asyncio.as_completed( - [ - _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) - for k, v in maybe_nodes.items() - ] - ), - total=len(maybe_nodes), - desc="Inserting entities", - unit="entity", + asyncio.as_completed( + [ + _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) + for k, v in maybe_nodes.items() + ] + ), + total=len(maybe_nodes), + desc="Inserting entities", + unit="entity", ): all_entities_data.append(await result) logger.info("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( - asyncio.as_completed( - [ - _merge_edges_then_upsert( - k[0], k[1], v, knowledge_graph_inst, global_config - ) - for k, v in maybe_edges.items() - ] - ), - 
total=len(maybe_edges), - desc="Inserting relationships", - unit="relationship", + asyncio.as_completed( + [ + _merge_edges_then_upsert( + k[0], k[1], v, knowledge_graph_inst, global_config + ) + for k, v in maybe_edges.items() + ] + ), + total=len(maybe_edges), + desc="Inserting relationships", + unit="relationship", ): all_relationships_data.append(await result) @@ -518,9 +527,9 @@ async def extract_entities( "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], "metadata": { "created_at": dp.get("metadata", {}).get("created_at", time.time()) }, @@ -533,14 +542,14 @@ async def extract_entities( async def kg_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: # Handle cache use_model_func = global_config["llm_model_func"] @@ -660,12 +669,12 @@ async def kg_query( async def _build_query_context( - query: list, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + query: list, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", "" # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", "" @@ -718,9 +727,9 @@ async def _build_query_context( query_param, ) if ( - hl_entities_context == "" - and hl_relations_context == "" - and hl_text_units_context == "" + hl_entities_context == "" + and hl_relations_context == "" + and hl_text_units_context == "" ): logger.warn("No high level context found. 
Switching to local mode.") query_param.mode = "local" @@ -759,11 +768,11 @@ async def _build_query_context( async def _get_node_data( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): # get similar entities results = await entities_vdb.query(query, top_k=query_param.top_k) @@ -850,10 +859,10 @@ async def _get_node_data( async def _find_most_related_text_unit_from_entities( - node_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage[TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) @@ -893,8 +902,8 @@ async def _find_most_related_text_unit_from_entities( if this_edges: for e in this_edges: if ( - e[1] in all_one_hop_text_units_lookup - and c_id in all_one_hop_text_units_lookup[e[1]] + e[1] in all_one_hop_text_units_lookup + and c_id in all_one_hop_text_units_lookup[e[1]] ): all_text_units_lookup[c_id]["relation_counts"] += 1 @@ -924,9 +933,9 @@ async def _find_most_related_text_unit_from_entities( async def _find_most_related_edges_from_entities( - node_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + node_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, ): all_related_edges = await asyncio.gather( *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] @@ -964,11 +973,11 @@ async def _find_most_related_edges_from_entities( async def _get_edge_data( - keywords, - knowledge_graph_inst: BaseGraphStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, + keywords, + knowledge_graph_inst: BaseGraphStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, ): results = await relationships_vdb.query(keywords, top_k=query_param.top_k) @@ -1066,9 +1075,9 @@ async def _get_edge_data( async def _find_most_related_entities_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, ): entity_names = [] seen = set() @@ -1103,10 +1112,10 @@ async def _find_most_related_entities_from_relationships( async def _find_related_text_unit_from_relationships( - edge_datas: list[dict], - query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], - knowledge_graph_inst: BaseGraphStorage, + edge_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage[TextChunkSchema], + knowledge_graph_inst: BaseGraphStorage, ): text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) @@ -1172,12 +1181,12 @@ def combine_contexts(entities, relationships, sources): async def naive_query( - query, - chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + chunks_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + 
query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ): # Handle cache use_model_func = global_config["llm_model_func"] @@ -1235,7 +1244,7 @@ async def naive_query( if len(response) > len(sys_prompt): response = ( - response[len(sys_prompt):] + response[len(sys_prompt) :] .replace(sys_prompt, "") .replace("user", "") .replace("model", "") @@ -1263,15 +1272,15 @@ async def naive_query( async def mix_kg_vector_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + query, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + chunks_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: """ Hybrid retrieval implementation combining knowledge graph and vector search. @@ -1296,7 +1305,7 @@ async def mix_kg_vector_query( # Reuse keyword extraction logic from kg_query example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len( - PROMPTS["keywords_extraction_examples"] + PROMPTS["keywords_extraction_examples"] ): examples = "\n".join( PROMPTS["keywords_extraction_examples"][: int(example_number)] diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000..2b9253b4 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,740 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4b5690db12e34685", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:34.174205Z", + "start_time": "2025-01-07T05:38:29.978194Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import logging\n", + "import numpy as np\n", + "from lightrag import LightRAG, QueryParam\n", + "from lightrag.llm import openai_complete_if_cache, openai_embedding\n", + "from lightrag.utils import EmbeddingFunc\n", + "import nest_asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8c8ee7c061bf9159", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:37.440083Z", + "start_time": "2025-01-07T05:38:37.437666Z" + } + }, + "outputs": [], + "source": [ + "nest_asyncio.apply()\n", + "WORKING_DIR = \"../llm_rag/paper_db/R000088_test2\"\n", + "logging.basicConfig(format=\"%(levelname)s:%(message)s\", level=logging.INFO)\n", + "if not os.path.exists(WORKING_DIR):\n", + " os.mkdir(WORKING_DIR)\n", + "os.environ[\"doubao_api\"] = \"6b890250-0cf6-4eb1-aa82-9c9d711398a7\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a5009d16e0851dca", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:42.594315Z", + "start_time": "2025-01-07T05:38:42.590800Z" + } + }, + "outputs": [], + "source": [ + "async def llm_model_func(\n", + " prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs\n", + ") -> str:\n", + " return await openai_complete_if_cache(\n", + " \"ep-20241218114828-2tlww\",\n", + " prompt,\n", + " system_prompt=system_prompt,\n", + " history_messages=history_messages,\n", + " api_key=os.getenv(\"doubao_api\"),\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " **kwargs,\n", + " )\n", + "\n", + "\n", + "async def embedding_func(texts: list[str]) -> np.ndarray:\n", + " return 
await openai_embedding(\n", + " texts,\n", + " model=\"ep-20241231173413-pgjmk\",\n", + " api_key=os.getenv(\"doubao_api\"),\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "397fcad24ce4d0ed", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:44.016901Z", + "start_time": "2025-01-07T05:38:44.006291Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:lightrag:Logger initialized for working directory: ../llm_rag/paper_db/R000088_test2\n", + "INFO:lightrag:Load KV llm_response_cache with 0 data\n", + "INFO:lightrag:Load KV full_docs with 0 data\n", + "INFO:lightrag:Load KV text_chunks with 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_entities.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_relationships.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_chunks.json'} 0 data\n", + "INFO:lightrag:Loaded document status storage with 0 records\n" + ] + } + ], + "source": [ + "rag = LightRAG(\n", + " working_dir=WORKING_DIR,\n", + " llm_model_func=llm_model_func,\n", + " embedding_func=EmbeddingFunc(\n", + " embedding_dim=4096, max_token_size=8192, func=embedding_func\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1dc3603677f7484d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:47.509111Z", + "start_time": "2025-01-07T05:38:47.501997Z" + } + }, + "outputs": [], + "source": [ + "with open(\n", + " \"../llm_rag/example/R000088/auto/R000088_full_txt.md\", \"r\", encoding=\"utf-8\"\n", + ") as f:\n", + " content = f.read()\n", + "\n", + "\n", + "async def embedding_func(texts: list[str]) -> np.ndarray:\n", + " return await openai_embedding(\n", + " texts,\n", + " model=\"ep-20241231173413-pgjmk\",\n", + " api_key=os.getenv(\"doubao_api\"),\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " )\n", + "\n", + "\n", + "async def get_embedding_dim():\n", + " test_text = [\"This is a test sentence.\"]\n", + " embedding = await embedding_func(test_text)\n", + " embedding_dim = embedding.shape[1]\n", + " return embedding_dim" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6844202606acfbe5", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:38:50.666764Z", + "start_time": "2025-01-07T05:38:50.247712Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n" + ] + } + ], + "source": [ + "embedding_dimension = await get_embedding_dim()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d6273839d9681403", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-07T05:42:33.085507Z", + "start_time": "2025-01-07T05:38:56.789348Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:lightrag:Processing 1 new unique documents\n", + "Processing batch 1: 0%| | 0/1 [00:00标签中,针对每个问题详细分析你的思考过程。然后在<回答>标签中给出所有问题的最终答案。\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7a6491385b050095", + "metadata": { + "ExecuteTime": { + "end_time": 
"2025-01-07T05:43:24.751628Z", + "start_time": "2025-01-07T05:42:50.865679Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n", + "INFO:lightrag:kw_prompt result:\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"high_level_keywords\": [\"英文学术研究论文分析\", \"关键信息提取\", \"深入分析\"],\n", + " \"low_level_keywords\": [\"研究队列\", \"队列名称\", \"队列开展国家\", \"性别分布\", \"年龄分布\", \"队列研究时间线\", \"实际参与研究人数\"]\n", + "}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:lightrag:Local query uses 60 entites, 38 relations, 6 text units\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:lightrag:Global query uses 72 entites, 60 relations, 4 text units\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<分析>\n", + "- **分析对象来自哪些研究队列及是单独分析还是联合分析**:\n", + " 通过查找论文内容,发现文中提到“This is a combined analysis of data from 2 randomized, double-blind, placebo-controlled clinical trials (Norwegian Vitamin [NORVIT] trial15 and Western Norway B Vitamin Intervention Trial [WENBIT]16)”,明确是对两个队列的数据进行联合分析,队列名称分别为“Norwegian Vitamin (NORVIT) trial”和“Western Norway B Vitamin Intervention Trial (WENBIT)”。\n", + "- **队列开展的国家**:\n", + " 文中多次提及研究在挪威进行,如“combined analyses and extended follow-up of 2 vitamin B intervention trials among patients with ischemic heart disease in Norway”,所以确定研究开展的国家是挪威。\n", + "- **队列研究对象的性别分布**:\n", + " 从“Mean (SD) age was 62.3 (11.0) years and 23.5% of participants were women”可知,研究对象包含男性和女性,即全体。\n", + "- **队列收集结束时研究对象年龄分布**:\n", + " 已知“Mean (SD) age was 62.3 (11.0) years”是基线时年龄信息,“Median (interquartile range) duration of extended follow-up through December 31, 2007, was 78 (61 - 90) months”,由于随访的中位时间是78个月(约6.5年),所以可推算队列收集结束时研究对象年龄均值约为62.3 + 6.5 = 68.8岁(标准差仍为11.0年)。\n", + "- **队列研究时间线**:\n", + " 根据“2 randomized, double-blind, placebo-controlled clinical trials (Norwegian Vitamin [NORVIT] trial15 and Western Norway B Vitamin Intervention Trial [WENBIT]16) conducted between 1998 and 2005, and an observational posttrial follow-up through December 31, 2007”可知,队列开始收集信息时间为1998年,结束时间为2007年12月31日。\n", + "- **队列结束时实际参与研究人数**:\n", + " 由“A total of 6837 individuals were included in the combined analyses, of whom 6261 (91.6%) participated in posttrial follow-up”可知,队列结束时实际参与研究人数为6261人。\n", + "\n", + "\n", + "<回答>\n", + "- 分析对象来自“Norwegian Vitamin (NORVIT) trial”和“Western Norway B Vitamin Intervention Trial (WENBIT)”两个研究队列,文中是对这两个队列的数据进行联合分析。\n", + "- 队列开展的国家是挪威。\n", + "- 队列研究对象的性别分布为全体。\n", + "- 队列收集结束时,研究对象年龄分布均值约为68.8岁,标准差为11.0年。\n", + "- 队列研究时间线为1998年开始收集信息/建立队列,2007年12月31日结束。\n", + "- 队列结束时实际参与研究人数是6261人。\n" + ] + } + ], + "source": [ + "print(rag.query(prompt1, param=QueryParam(mode=\"hybrid\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fef9d06983da47af", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", 
+ "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 6c78c96854d9ab563a547546dd8652ed59190bd2 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 7 Jan 2025 22:02:34 +0800 Subject: [PATCH 09/38] fix linting errors --- lightrag/operate.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 59e9f648..ce7b0a8a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -4,7 +4,6 @@ import re from tqdm.asyncio import tqdm as tqdm_async from typing import Union from collections import Counter, defaultdict -import warnings from .utils import ( logger, clean_str, @@ -605,10 +604,16 @@ async def kg_query( logger.warning("low_level_keywords and high_level_keywords is empty") return PROMPTS["fail_response"] if ll_keywords == [] and query_param.mode in ["local", "hybrid"]: - logger.warning("low_level_keywords is empty, switching from %s mode to global mode", query_param.mode) + logger.warning( + "low_level_keywords is empty, switching from %s mode to global mode", + query_param.mode, + ) query_param.mode = "global" if hl_keywords == [] and query_param.mode in ["global", "hybrid"]: - logger.warning("high_level_keywords is empty, switching from %s mode to local mode", query_param.mode) + logger.warning( + "high_level_keywords is empty, switching from %s mode to local mode", + query_param.mode, + ) query_param.mode = "local" ll_keywords = ", ".join(ll_keywords) if ll_keywords else "" @@ -699,14 +704,22 @@ async def _build_query_context( query_param, ) else: # hybrid mode - ll_entities_context, ll_relations_context, ll_text_units_context = await _get_node_data( + ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) - hl_entities_context, hl_relations_context, hl_text_units_context = await _get_edge_data( + ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, From a9402513909606c76a2e8d5e040f12ecb8aa4739 Mon Sep 17 00:00:00 2001 From: Gurjot Singh Date: Tue, 7 Jan 2025 20:57:39 +0530 Subject: [PATCH 10/38] Implement custom chunking feature --- lightrag/lightrag.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7496d736..2225b2d1 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -458,6 +458,72 @@ class LightRAG: # Ensure all indexes are updated after each document await self._insert_done() + def insert_custom_chunks(self, full_text: str, text_chunks: list[str]): + loop = always_get_an_event_loop() + return loop.run_until_complete(self.ainsert_custom_chunks(full_text, text_chunks)) + + async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]): + + update_storage = False + try: + doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-") + new_docs = { + doc_key: {"content": full_text.strip()} + } + + _add_doc_keys = await self.full_docs.filter_keys([doc_key]) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + if not len(new_docs): + logger.warning("This document is already in the storage.") + return + + update_storage = True + logger.info(f"[New Docs] inserting {len(new_docs)} docs") + + inserting_chunks = {} + for chunk_text in 
text_chunks: + chunk_text_stripped = chunk_text.strip() + chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-") + + inserting_chunks[chunk_key] = { + "content": chunk_text_stripped, + "full_doc_id": doc_key, + } + + _add_chunk_keys = await self.text_chunks.filter_keys(list(inserting_chunks.keys())) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } + if not len(inserting_chunks): + logger.warning("All chunks are already in the storage.") + return + + logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") + + await self.chunks_vdb.upsert(inserting_chunks) + + logger.info("[Entity Extraction]...") + maybe_new_kg = await extract_entities( + inserting_chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + global_config=asdict(self), + ) + + if maybe_new_kg is None: + logger.warning("No new entities and relationships found") + return + else: + self.chunk_entity_relation_graph = maybe_new_kg + + await self.full_docs.upsert(new_docs) + await self.text_chunks.upsert(inserting_chunks) + + finally: + if update_storage: + await self._insert_done() + async def _insert_done(self): tasks = [] for storage_inst in [ From 9e7784ab8a642415432c742d8e891f6173886f66 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:17:32 +0800 Subject: [PATCH 11/38] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f66fb3ce..6c981d92 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@


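For reference, a minimal usage sketch of the `insert_custom_chunks` API introduced in the custom chunking patch above. The working directory, document text, and chunk strings are placeholders, and the `LightRAG` constructor is assumed to run with its default LLM/embedding settings; real deployments pass their own `llm_model_func`/`embedding_func` as in the repository examples.

```python
from lightrag import LightRAG

# Minimal sketch: assumes the default llm_model_func/embedding_func settings;
# production code configures these explicitly as shown in the README examples.
rag = LightRAG(working_dir="./rag_storage")

full_text = "Full document text ..."          # placeholder document
text_chunks = [
    "First manually prepared chunk ...",      # placeholder, pre-chunked passages
    "Second manually prepared chunk ...",
]

# Each chunk is hashed to a chunk-id, upserted into the chunks vector store,
# and run through entity extraction, mirroring ainsert_custom_chunks above.
rag.insert_custom_chunks(full_text, text_chunks)
```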
From 9565a4663ad8878126f16d667455ca5a22f1d557 Mon Sep 17 00:00:00 2001 From: Gurjot Singh Date: Thu, 9 Jan 2025 00:39:22 +0530 Subject: [PATCH 12/38] Fix trailing whitespace and formatting issues in lightrag.py --- lightrag/lightrag.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2225b2d1..6af29aa2 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -460,16 +460,15 @@ class LightRAG: def insert_custom_chunks(self, full_text: str, text_chunks: list[str]): loop = always_get_an_event_loop() - return loop.run_until_complete(self.ainsert_custom_chunks(full_text, text_chunks)) + return loop.run_until_complete( + self.ainsert_custom_chunks(full_text, text_chunks) + ) async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]): - update_storage = False try: doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-") - new_docs = { - doc_key: {"content": full_text.strip()} - } + new_docs = {doc_key: {"content": full_text.strip()}} _add_doc_keys = await self.full_docs.filter_keys([doc_key]) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} @@ -484,13 +483,15 @@ class LightRAG: for chunk_text in text_chunks: chunk_text_stripped = chunk_text.strip() chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-") - + inserting_chunks[chunk_key] = { "content": chunk_text_stripped, "full_doc_id": doc_key, } - _add_chunk_keys = await self.text_chunks.filter_keys(list(inserting_chunks.keys())) + _add_chunk_keys = await self.text_chunks.filter_keys( + list(inserting_chunks.keys()) + ) inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } From 65c1450c66a769e9134e900a87706f9bc4ab5a97 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 8 Jan 2025 20:50:22 +0100 Subject: [PATCH 13/38] fixed retro compatibility with ainsert by making split_by_character get a None default value --- lightrag/lightrag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7496d736..362b7275 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -320,7 +320,7 @@ class LightRAG: self.ainsert(string_or_strings, split_by_character) ) - async def ainsert(self, string_or_strings, split_by_character): + async def ainsert(self, string_or_strings, split_by_character=None): """Insert documents with checkpoint support Args: From dd213c95be5c63bc61f399f14612028fd40a4a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A5=E7=9F=B3=E6=B8=8A?= Date: Thu, 9 Jan 2025 11:55:49 +0800 Subject: [PATCH 14/38] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BB=85=E5=AD=97?= =?UTF-8?q?=E7=AC=A6=E5=88=86=E5=89=B2=E5=8F=82=E6=95=B0=EF=BC=8C=E5=A6=82?= =?UTF-8?q?=E6=9E=9C=E5=BC=80=E5=90=AF=EF=BC=8C=E4=BB=85=E9=87=87=E7=94=A8?= =?UTF-8?q?=E5=AD=97=E7=AC=A6=E5=88=86=E5=89=B2=EF=BC=8C=E4=B8=8D=E5=BC=80?= =?UTF-8?q?=E5=90=AF=EF=BC=8C=E5=9C=A8=E5=88=86=E5=89=B2=E5=AE=8C=E4=BB=A5?= =?UTF-8?q?=E5=90=8E=E5=A6=82=E6=9E=9Cchunk=E8=BF=87=E5=A4=A7=EF=BC=8C?= =?UTF-8?q?=E4=BC=9A=E7=BB=A7=E7=BB=AD=E6=A0=B9=E6=8D=AEtoken=20size?= =?UTF-8?q?=E5=88=86=E5=89=B2=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/test_split_by_character.ipynb | 1296 ++++++++++++++++++++++++ lightrag/lightrag.py | 16 +- lightrag/operate.py | 34 +- test.ipynb | 740 -------------- 4 files changed, 1328 insertions(+), 758 deletions(-) 
create mode 100644 examples/test_split_by_character.ipynb delete mode 100644 test.ipynb diff --git a/examples/test_split_by_character.ipynb b/examples/test_split_by_character.ipynb new file mode 100644 index 00000000..e8e08b92 --- /dev/null +++ b/examples/test_split_by_character.ipynb @@ -0,0 +1,1296 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4b5690db12e34685", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:40:58.307102Z", + "start_time": "2025-01-09T03:40:51.935233Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import logging\n", + "import numpy as np\n", + "from lightrag import LightRAG, QueryParam\n", + "from lightrag.llm import openai_complete_if_cache, openai_embedding\n", + "from lightrag.utils import EmbeddingFunc\n", + "import nest_asyncio" + ] + }, + { + "cell_type": "markdown", + "id": "dd17956ec322b361", + "metadata": {}, + "source": "#### split by character" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8c8ee7c061bf9159", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:41:13.961167Z", + "start_time": "2025-01-09T03:41:13.958357Z" + } + }, + "outputs": [], + "source": [ + "nest_asyncio.apply()\n", + "WORKING_DIR = \"../../llm_rag/paper_db/R000088_test1\"\n", + "logging.basicConfig(format=\"%(levelname)s:%(message)s\", level=logging.INFO)\n", + "if not os.path.exists(WORKING_DIR):\n", + " os.mkdir(WORKING_DIR)\n", + "API = os.environ.get(\"DOUBAO_API_KEY\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a5009d16e0851dca", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:41:16.862036Z", + "start_time": "2025-01-09T03:41:16.859306Z" + } + }, + "outputs": [], + "source": [ + "async def llm_model_func(\n", + " prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs\n", + ") -> str:\n", + " return await openai_complete_if_cache(\n", + " \"ep-20241218114828-2tlww\",\n", + " prompt,\n", + " system_prompt=system_prompt,\n", + " history_messages=history_messages,\n", + " api_key=API,\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " **kwargs,\n", + " )\n", + "\n", + "\n", + "async def embedding_func(texts: list[str]) -> np.ndarray:\n", + " return await openai_embedding(\n", + " texts,\n", + " model=\"ep-20241231173413-pgjmk\",\n", + " api_key=API,\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "397fcad24ce4d0ed", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:41:24.950307Z", + "start_time": "2025-01-09T03:41:24.940353Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:lightrag:Logger initialized for working directory: ../../llm_rag/paper_db/R000088_test1\n", + "INFO:lightrag:Load KV llm_response_cache with 0 data\n", + "INFO:lightrag:Load KV full_docs with 0 data\n", + "INFO:lightrag:Load KV text_chunks with 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../../llm_rag/paper_db/R000088_test1/vdb_entities.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../../llm_rag/paper_db/R000088_test1/vdb_relationships.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../../llm_rag/paper_db/R000088_test1/vdb_chunks.json'} 0 data\n", + "INFO:lightrag:Loaded document status storage with 0 records\n" + 
] + } + ], + "source": [ + "rag = LightRAG(\n", + " working_dir=WORKING_DIR,\n", + " llm_model_func=llm_model_func,\n", + " embedding_func=EmbeddingFunc(\n", + " embedding_dim=4096, max_token_size=8192, func=embedding_func\n", + " ),\n", + " chunk_token_size=512,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1dc3603677f7484d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:41:37.947456Z", + "start_time": "2025-01-09T03:41:37.941901Z" + } + }, + "outputs": [], + "source": [ + "with open(\n", + " \"../../llm_rag/example/R000088/auto/R000088_full_txt.md\", \"r\", encoding=\"utf-8\"\n", + ") as f:\n", + " content = f.read()\n", + "\n", + "\n", + "async def embedding_func(texts: list[str]) -> np.ndarray:\n", + " return await openai_embedding(\n", + " texts,\n", + " model=\"ep-20241231173413-pgjmk\",\n", + " api_key=API,\n", + " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", + " )\n", + "\n", + "\n", + "async def get_embedding_dim():\n", + " test_text = [\"This is a test sentence.\"]\n", + " embedding = await embedding_func(test_text)\n", + " embedding_dim = embedding.shape[1]\n", + " return embedding_dim" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6844202606acfbe5", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:41:39.608541Z", + "start_time": "2025-01-09T03:41:39.165057Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n" + ] + } + ], + "source": [ + "embedding_dimension = await get_embedding_dim()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d6273839d9681403", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:44:34.295345Z", + "start_time": "2025-01-09T03:41:48.324171Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:lightrag:Processing 1 new unique documents\n", + "Processing batch 1: 0%| | 0/1 [00:00标签中,针对每个问题详细分析你的思考过程。然后在<回答>标签中给出所有问题的最终答案。\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7a6491385b050095", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:45:40.829111Z", + "start_time": "2025-01-09T03:45:13.530298Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:lightrag:Local query uses 5 entites, 12 relations, 3 text units\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:lightrag:Global query uses 8 entites, 5 relations, 4 text units\n", + "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<分析>\n", + "1. **该文献主要研究的问题是什么?**\n", + " - 思考过程:通过浏览论文内容,查找作者明确阐述研究目的的部分。文中多处提及“Our study was performed to explore whether folic acid treatment was associated with cancer outcomes and all-cause mortality after extended follow-up”,表明作者旨在探究叶酸治疗与癌症结局及全因死亡率之间的关系,尤其是在经过长期随访后。\n", + "2. 
**该文献采用什么方法进行分析?**\n", + " - 思考过程:寻找描述研究方法和数据分析过程的段落。文中提到“Survival curves were constructed using the Kaplan-Meier method and differences in survival between groups were analyzed using the log-rank test. Estimates of hazard ratios (HRs) with 95% CIs were obtained by using Cox proportional hazards regression models stratified by trial”,可以看出作者使用了Kaplan-Meier法构建生存曲线、log-rank检验分析组间生存差异以及Cox比例风险回归模型估计风险比等方法。\n", + "3. **该文献的主要结论是什么?**\n", + " - 思考过程:定位到论文中总结结论的部分,如“Conclusion Treatment with folic acid plus vitamin $\\mathsf{B}_{12}$ was associated with increased cancer outcomes and all-cause mortality in patients with ischemic heart disease in Norway, where there is no folic acid fortification of foods”,可知作者得出叶酸加维生素$\\mathsf{B}_{12}$治疗与癌症结局和全因死亡率增加有关的结论。\n", + "<回答>\n", + "1. 该文献主要研究的问题是:叶酸治疗与癌症结局及全因死亡率之间的关系,尤其是在经过长期随访后,叶酸治疗是否与癌症结局和全因死亡率相关。\n", + "2. 该文献采用的分析方法包括:使用Kaplan-Meier法构建生存曲线、log-rank检验分析组间生存差异、Cox比例风险回归模型估计风险比等。\n", + "3. 该文献的主要结论是:在挪威没有叶酸强化食品的情况下,叶酸加维生素$\\mathsf{B}_{12}$治疗与缺血性心脏病患者的癌症结局和全因死亡率增加有关。\n", + "\n", + "**参考文献**\n", + "- [VD] In2Norwegianhomocysteine-lowering trialsamongpatientswithischemicheart disease, there was a statistically nonsignificantincreaseincancerincidenceinthe groupsassignedtofolicacidtreatment.15,16 Our study was performed to explore whetherfolicacidtreatmentwasassociatedwithcanceroutcomesandall-cause mortality after extended follow-up.\n", + "- [VD] Survivalcurveswereconstructedusing theKaplan-Meiermethodanddifferences insurvivalbetweengroupswereanalyzed usingthelog-ranktest.Estimatesofhazard ratios (HRs) with $95\\%$ CIs were obtainedbyusingCoxproportionalhazards regressionmodelsstratifiedbytrial.\n", + "- [VD] Conclusion Treatment with folic acid plus vitamin $\\mathsf{B}_{12}$ was associated with increased cancer outcomes and all-cause mortality in patients with ischemic heart disease in Norway, where there is no folic acid fortification of foods.\n" + ] + } + ], + "source": [ + "resp = rag.query(prompt1, param=QueryParam(mode=\"mix\", top_k=5))\n", + "print(resp)" + ] + }, + { + "cell_type": "markdown", + "id": "4e5bfad24cb721a8", + "metadata": {}, + "source": "#### split by character only" + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "44e2992dc95f8ce0", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:47:40.988796Z", + "start_time": "2025-01-09T03:47:40.982648Z" + } + }, + "outputs": [], + "source": [ + "WORKING_DIR = \"../../llm_rag/paper_db/R000088_test2\"\n", + "if not os.path.exists(WORKING_DIR):\n", + " os.mkdir(WORKING_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "62c63385d2d973d5", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-09T03:51:39.951329Z", + "start_time": "2025-01-09T03:49:15.218976Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:lightrag:Logger initialized for working directory: ../../llm_rag/paper_db/R000088_test2\n", + "INFO:lightrag:Load KV llm_response_cache with 0 data\n", + "INFO:lightrag:Load KV full_docs with 0 data\n", + "INFO:lightrag:Load KV text_chunks with 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../../llm_rag/paper_db/R000088_test2/vdb_entities.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../../llm_rag/paper_db/R000088_test2/vdb_relationships.json'} 0 data\n", + "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': 
'../../llm_rag/paper_db/R000088_test2/vdb_chunks.json'} 0 data\n", + "INFO:lightrag:Loaded document status storage with 0 records\n", + "INFO:lightrag:Processing 1 new unique documents\n", + "Processing batch 1: 0%| | 0/1 [00:00\n", + "- **该文献主要研究的问题是什么?**\n", + " - **思考过程**:通过浏览论文的标题、摘要、引言等部分,寻找关于研究目的和问题的描述。论文标题为“Cancer Incidence and Mortality After Treatment With Folic Acid and Vitamin B12”,摘要中的“Objective”部分明确指出研究目的是“To evaluate effects of treatment with B vitamins on cancer outcomes and all-cause mortality in 2 randomized controlled trials”。因此,可以确定该文献主要研究的问题是评估B族维生素治疗对两项随机对照试验中癌症结局和全因死亡率的影响。\n", + "- **该文献采用什么方法进行分析?**\n", + " - **思考过程**:在论文的“METHODS”部分详细描述了研究方法。文中提到这是一个对两项随机、双盲、安慰剂对照临床试验(Norwegian Vitamin [NORVIT] trial和Western Norway B Vitamin Intervention Trial [WENBIT])数据的联合分析,并进行了观察性的试验后随访。具体包括对参与者进行分组干预(不同剂量的叶酸、维生素B12、维生素B6或安慰剂),收集临床信息和血样,分析循环B族维生素、同型半胱氨酸和可替宁等指标,并进行基因分型等,还涉及到多种统计分析方法,如计算预期癌症发生率、构建生存曲线、进行Cox比例风险回归模型分析等。\n", + "- **该文献的主要结论是什么?**\n", + " - **思考过程**:在论文的“Results”和“Conclusion”部分寻找主要结论。研究结果表明,在治疗期间,接受叶酸加维生素B12治疗的参与者血清叶酸浓度显著增加,且在后续随访中,该组癌症发病率、癌症死亡率和全因死亡率均有所上升,主要是肺癌发病率增加,而维生素B6治疗未显示出显著影响。结论部分明确指出“Treatment with folic acid plus vitamin $\\mathsf{B}_{12}$ was associated with increased cancer outcomes and all-cause mortality in patients with ischemic heart disease in Norway, where there is no folic acid fortification of foods”。\n", + "\n", + "\n", + "<回答>\n", + "- **主要研究问题**:评估B族维生素治疗对两项随机对照试验中癌症结局和全因死亡率的影响。\n", + "- **研究方法**:采用对两项随机、双盲、安慰剂对照临床试验(Norwegian Vitamin [NORVIT] trial和Western Norway B Vitamin Intervention Trial [WENBIT])数据的联合分析,并进行观察性的试验后随访,涉及分组干预、多种指标检测以及多种统计分析方法。\n", + "- **主要结论**:在挪威(食品中未添加叶酸),对于缺血性心脏病患者,叶酸加维生素B12治疗与癌症结局和全因死亡率的增加有关,而维生素B6治疗未显示出显著影响。\n", + "\n", + "**参考文献**\n", + "- [VD] Cancer Incidence and Mortality After Treatment With Folic Acid and Vitamin B12\n", + "- [VD] METHODS Study Design, Participants, and Study Intervention\n", + "- [VD] RESULTS\n", + "- [VD] Conclusion\n", + "- [VD] Objective To evaluate effects of treatment with B vitamins on cancer outcomes and all-cause mortality in 2 randomized controlled trials.\n" + ] + } + ], + "source": [ + "resp = rag.query(prompt1, param=QueryParam(mode=\"mix\", top_k=5))\n", + "print(resp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ba6fa79a2550d10", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7496d736..b94ff821 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -314,18 +314,25 @@ class LightRAG: "JsonDocStatusStorage": JsonDocStatusStorage, } - def insert(self, string_or_strings, split_by_character=None): + def insert( + self, string_or_strings, split_by_character=None, split_by_character_only=False + ): loop = always_get_an_event_loop() return loop.run_until_complete( - self.ainsert(string_or_strings, split_by_character) + self.ainsert(string_or_strings, split_by_character, split_by_character_only) ) - async def ainsert(self, string_or_strings, split_by_character): + async def ainsert( + self, string_or_strings, split_by_character, split_by_character_only + ): """Insert 
documents with checkpoint support Args: string_or_strings: Single document string or list of document strings - split_by_character: if split_by_character is not None, split the string by character + split_by_character: if split_by_character is not None, split the string by character, if chunk longer than + chunk_size, split the sub chunk by token size. + split_by_character_only: if split_by_character_only is True, split the string by character only, when + split_by_character is None, this parameter is ignored. """ if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] @@ -384,6 +391,7 @@ class LightRAG: for dp in chunking_by_token_size( doc["content"], split_by_character=split_by_character, + split_by_character_only=split_by_character_only, overlap_token_size=self.chunk_overlap_token_size, max_token_size=self.chunk_token_size, tiktoken_model=self.tiktoken_model_name, diff --git a/lightrag/operate.py b/lightrag/operate.py index 1128b41c..58ae3703 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -36,6 +36,7 @@ import time def chunking_by_token_size( content: str, split_by_character=None, + split_by_character_only=False, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o", @@ -45,21 +46,26 @@ def chunking_by_token_size( if split_by_character: raw_chunks = content.split(split_by_character) new_chunks = [] - for chunk in raw_chunks: - _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model) - if len(_tokens) > max_token_size: - for start in range( - 0, len(_tokens), max_token_size - overlap_token_size - ): - chunk_content = decode_tokens_by_tiktoken( - _tokens[start : start + max_token_size], - model_name=tiktoken_model, - ) - new_chunks.append( - (min(max_token_size, len(_tokens) - start), chunk_content) - ) - else: + if split_by_character_only: + for chunk in raw_chunks: + _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model) new_chunks.append((len(_tokens), chunk)) + else: + for chunk in raw_chunks: + _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model) + if len(_tokens) > max_token_size: + for start in range( + 0, len(_tokens), max_token_size - overlap_token_size + ): + chunk_content = decode_tokens_by_tiktoken( + _tokens[start : start + max_token_size], + model_name=tiktoken_model, + ) + new_chunks.append( + (min(max_token_size, len(_tokens) - start), chunk_content) + ) + else: + new_chunks.append((len(_tokens), chunk)) for index, (_len, chunk) in enumerate(new_chunks): results.append( { diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 2b9253b4..00000000 --- a/test.ipynb +++ /dev/null @@ -1,740 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "4b5690db12e34685", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:34.174205Z", - "start_time": "2025-01-07T05:38:29.978194Z" - } - }, - "outputs": [], - "source": [ - "import os\n", - "import logging\n", - "import numpy as np\n", - "from lightrag import LightRAG, QueryParam\n", - "from lightrag.llm import openai_complete_if_cache, openai_embedding\n", - "from lightrag.utils import EmbeddingFunc\n", - "import nest_asyncio" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8c8ee7c061bf9159", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:37.440083Z", - "start_time": "2025-01-07T05:38:37.437666Z" - } - }, - "outputs": [], - "source": [ - "nest_asyncio.apply()\n", - "WORKING_DIR = \"../llm_rag/paper_db/R000088_test2\"\n", - 
"logging.basicConfig(format=\"%(levelname)s:%(message)s\", level=logging.INFO)\n", - "if not os.path.exists(WORKING_DIR):\n", - " os.mkdir(WORKING_DIR)\n", - "os.environ[\"doubao_api\"] = \"6b890250-0cf6-4eb1-aa82-9c9d711398a7\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "a5009d16e0851dca", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:42.594315Z", - "start_time": "2025-01-07T05:38:42.590800Z" - } - }, - "outputs": [], - "source": [ - "async def llm_model_func(\n", - " prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs\n", - ") -> str:\n", - " return await openai_complete_if_cache(\n", - " \"ep-20241218114828-2tlww\",\n", - " prompt,\n", - " system_prompt=system_prompt,\n", - " history_messages=history_messages,\n", - " api_key=os.getenv(\"doubao_api\"),\n", - " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", - " **kwargs,\n", - " )\n", - "\n", - "\n", - "async def embedding_func(texts: list[str]) -> np.ndarray:\n", - " return await openai_embedding(\n", - " texts,\n", - " model=\"ep-20241231173413-pgjmk\",\n", - " api_key=os.getenv(\"doubao_api\"),\n", - " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "397fcad24ce4d0ed", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:44.016901Z", - "start_time": "2025-01-07T05:38:44.006291Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:lightrag:Logger initialized for working directory: ../llm_rag/paper_db/R000088_test2\n", - "INFO:lightrag:Load KV llm_response_cache with 0 data\n", - "INFO:lightrag:Load KV full_docs with 0 data\n", - "INFO:lightrag:Load KV text_chunks with 0 data\n", - "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_entities.json'} 0 data\n", - "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_relationships.json'} 0 data\n", - "INFO:nano-vectordb:Init {'embedding_dim': 4096, 'metric': 'cosine', 'storage_file': '../llm_rag/paper_db/R000088_test2/vdb_chunks.json'} 0 data\n", - "INFO:lightrag:Loaded document status storage with 0 records\n" - ] - } - ], - "source": [ - "rag = LightRAG(\n", - " working_dir=WORKING_DIR,\n", - " llm_model_func=llm_model_func,\n", - " embedding_func=EmbeddingFunc(\n", - " embedding_dim=4096, max_token_size=8192, func=embedding_func\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1dc3603677f7484d", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:47.509111Z", - "start_time": "2025-01-07T05:38:47.501997Z" - } - }, - "outputs": [], - "source": [ - "with open(\n", - " \"../llm_rag/example/R000088/auto/R000088_full_txt.md\", \"r\", encoding=\"utf-8\"\n", - ") as f:\n", - " content = f.read()\n", - "\n", - "\n", - "async def embedding_func(texts: list[str]) -> np.ndarray:\n", - " return await openai_embedding(\n", - " texts,\n", - " model=\"ep-20241231173413-pgjmk\",\n", - " api_key=os.getenv(\"doubao_api\"),\n", - " base_url=\"https://ark.cn-beijing.volces.com/api/v3\",\n", - " )\n", - "\n", - "\n", - "async def get_embedding_dim():\n", - " test_text = [\"This is a test sentence.\"]\n", - " embedding = await embedding_func(test_text)\n", - " embedding_dim = embedding.shape[1]\n", - " return embedding_dim" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - 
"id": "6844202606acfbe5", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:38:50.666764Z", - "start_time": "2025-01-07T05:38:50.247712Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n" - ] - } - ], - "source": [ - "embedding_dimension = await get_embedding_dim()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "d6273839d9681403", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:42:33.085507Z", - "start_time": "2025-01-07T05:38:56.789348Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:lightrag:Processing 1 new unique documents\n", - "Processing batch 1: 0%| | 0/1 [00:00标签中,针对每个问题详细分析你的思考过程。然后在<回答>标签中给出所有问题的最终答案。\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7a6491385b050095", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-07T05:43:24.751628Z", - "start_time": "2025-01-07T05:42:50.865679Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n", - "INFO:lightrag:kw_prompt result:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\n", - " \"high_level_keywords\": [\"英文学术研究论文分析\", \"关键信息提取\", \"深入分析\"],\n", - " \"low_level_keywords\": [\"研究队列\", \"队列名称\", \"队列开展国家\", \"性别分布\", \"年龄分布\", \"队列研究时间线\", \"实际参与研究人数\"]\n", - "}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", - "INFO:lightrag:Local query uses 60 entites, 38 relations, 6 text units\n", - "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/embeddings \"HTTP/1.1 200 OK\"\n", - "INFO:lightrag:Global query uses 72 entites, 60 relations, 4 text units\n", - "INFO:httpx:HTTP Request: POST https://ark.cn-beijing.volces.com/api/v3/chat/completions \"HTTP/1.1 200 OK\"\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "<分析>\n", - "- **分析对象来自哪些研究队列及是单独分析还是联合分析**:\n", - " 通过查找论文内容,发现文中提到“This is a combined analysis of data from 2 randomized, double-blind, placebo-controlled clinical trials (Norwegian Vitamin [NORVIT] trial15 and Western Norway B Vitamin Intervention Trial [WENBIT]16)”,明确是对两个队列的数据进行联合分析,队列名称分别为“Norwegian Vitamin (NORVIT) trial”和“Western Norway B Vitamin Intervention Trial (WENBIT)”。\n", - "- **队列开展的国家**:\n", - " 文中多次提及研究在挪威进行,如“combined analyses and extended follow-up of 2 vitamin B intervention trials among patients with ischemic heart disease in Norway”,所以确定研究开展的国家是挪威。\n", - "- **队列研究对象的性别分布**:\n", - " 从“Mean (SD) age was 62.3 (11.0) years and 23.5% of participants were women”可知,研究对象包含男性和女性,即全体。\n", - "- **队列收集结束时研究对象年龄分布**:\n", - " 已知“Mean (SD) age was 62.3 (11.0) years”是基线时年龄信息,“Median (interquartile range) duration of extended follow-up through December 31, 2007, was 78 (61 - 90) months”,由于随访的中位时间是78个月(约6.5年),所以可推算队列收集结束时研究对象年龄均值约为62.3 + 6.5 = 68.8岁(标准差仍为11.0年)。\n", - "- **队列研究时间线**:\n", - " 根据“2 randomized, double-blind, placebo-controlled clinical trials (Norwegian Vitamin [NORVIT] trial15 and Western Norway B Vitamin Intervention Trial [WENBIT]16) conducted between 1998 and 2005, and an observational posttrial follow-up through December 31, 2007”可知,队列开始收集信息时间为1998年,结束时间为2007年12月31日。\n", - "- 
**队列结束时实际参与研究人数**:\n", - " 由“A total of 6837 individuals were included in the combined analyses, of whom 6261 (91.6%) participated in posttrial follow-up”可知,队列结束时实际参与研究人数为6261人。\n", - "\n", - "\n", - "<回答>\n", - "- 分析对象来自“Norwegian Vitamin (NORVIT) trial”和“Western Norway B Vitamin Intervention Trial (WENBIT)”两个研究队列,文中是对这两个队列的数据进行联合分析。\n", - "- 队列开展的国家是挪威。\n", - "- 队列研究对象的性别分布为全体。\n", - "- 队列收集结束时,研究对象年龄分布均值约为68.8岁,标准差为11.0年。\n", - "- 队列研究时间线为1998年开始收集信息/建立队列,2007年12月31日结束。\n", - "- 队列结束时实际参与研究人数是6261人。\n" - ] - } - ], - "source": [ - "print(rag.query(prompt1, param=QueryParam(mode=\"hybrid\")))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fef9d06983da47af", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From acde4ed173614bca12f50a7a2f185b7f6f0ef2c1 Mon Sep 17 00:00:00 2001 From: adikalra <54812001+AdiKalra@users.noreply.github.com> Date: Thu, 9 Jan 2025 17:20:24 +0530 Subject: [PATCH 15/38] Add custom chunking function. --- lightrag/lightrag.py | 7 ++++++- lightrag/operate.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9905ee74..596fbdbf 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -187,6 +187,10 @@ class LightRAG: # Add new field for document status storage type doc_status_storage: str = field(default="JsonDocStatusStorage") + # Custom Chunking Function + chunking_func: callable = chunking_by_token_size + chunking_func_kwargs: dict = field(default_factory=dict) + def __post_init__(self): log_file = os.path.join("lightrag.log") set_logger(log_file) @@ -388,13 +392,14 @@ class LightRAG: **dp, "full_doc_id": doc_id, } - for dp in chunking_by_token_size( + for dp in self.chunking_func( doc["content"], split_by_character=split_by_character, split_by_character_only=split_by_character_only, overlap_token_size=self.chunk_overlap_token_size, max_token_size=self.chunk_token_size, tiktoken_model=self.tiktoken_model_name, + **self.chunking_func_kwargs, ) } diff --git a/lightrag/operate.py b/lightrag/operate.py index 09871659..7216c07f 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -39,6 +39,7 @@ def chunking_by_token_size( overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o", + **kwargs, ): tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) results = [] From 2297007b7b240e17fa77e4dc5aad228a8b0a1b65 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 20:30:58 +0100 Subject: [PATCH 16/38] Simplified the api services issue #565 --- README.md | 133 ++--- lightrag/api/azure_openai_lightrag_server.py | 532 ------------------ ..._lightrag_server.py => lightrag_server.py} | 94 +++- lightrag/api/ollama_lightrag_server.py | 491 ---------------- lightrag/api/openai_lightrag_server.py | 506 ----------------- setup.py | 5 +- 6 files changed, 136 insertions(+), 1625 deletions(-) delete mode 100644 lightrag/api/azure_openai_lightrag_server.py rename lightrag/api/{lollms_lightrag_server.py => lightrag_server.py} (82%) delete mode 100644 lightrag/api/ollama_lightrag_server.py delete mode 100644 
lightrag/api/openai_lightrag_server.py diff --git a/README.md b/README.md index 6c981d92..278f6a72 100644 --- a/README.md +++ b/README.md @@ -912,12 +912,14 @@ pip install -e ".[api]" ### Prerequisites -Before running any of the servers, ensure you have the corresponding backend service running: +Before running any of the servers, ensure you have the corresponding backend service running for both llm and embedding. +The new api allows you to mix different bindings for llm/embeddings. +For example, you have the possibility to use ollama for the embedding and openai for the llm. #### For LoLLMs Server - LoLLMs must be running and accessible - Default connection: http://localhost:9600 -- Configure using --lollms-host if running on a different host/port +- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port #### For Ollama Server - Ollama must be running and accessible @@ -953,15 +955,19 @@ The output of the last command will give you the endpoint and the key for the Op Each server has its own specific configuration options: -#### LoLLMs Server Options +#### LightRag Server Options | Parameter | Default | Description | |-----------|---------|-------------| | --host | 0.0.0.0 | RAG server host | | --port | 9621 | RAG server port | +| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai (default: ollama) | +| --llm-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | llm server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | | --model | mistral-nemo:latest | LLM model name | +| --embedding-binding | ollama | Embedding binding to be used. 
Supported: lollms, ollama, openai (default: ollama) | +| --embedding-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | embedding server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | | --embedding-model | bge-m3:latest | Embedding model name | -| --lollms-host | http://localhost:9600 | LoLLMS backend URL | +| --embedding-binding-host | http://localhost:9600 | LoLLMS backend URL | | --working-dir | ./rag_storage | Working directory for RAG | | --max-async | 4 | Maximum async operations | | --max-tokens | 32768 | Maximum token size | @@ -971,95 +977,71 @@ Each server has its own specific configuration options: | --log-level | INFO | Logging level | | --key | none | Access Key to protect the lightrag service | -#### Ollama Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --model | mistral-nemo:latest | LLM model name | -| --embedding-model | bge-m3:latest | Embedding model name | -| --ollama-host | http://localhost:11434 | Ollama backend URL | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-async | 4 | Maximum async operations | -| --max-tokens | 32768 | Maximum token size | -| --embedding-dim | 1024 | Embedding dimensions | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-file | ./book.txt | Initial input file | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | - -#### OpenAI Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --model | gpt-4 | OpenAI model name | -| --embedding-model | text-embedding-3-large | OpenAI embedding model | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-tokens | 32768 | Maximum token size | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-dir | ./inputs | Input directory for documents | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | - -#### OpenAI AZURE Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | Server host | -| --port | 9621 | Server port | -| --model | gpt-4 | OpenAI model name | -| --embedding-model | text-embedding-3-large | OpenAI embedding model | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-tokens | 32768 | Maximum token size | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-dir | ./inputs | Input directory for documents | -| --enable-cache | True | Enable response cache | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`. ### Example Usage -#### LoLLMs RAG Server +#### Running a Lightrag server with ollama default local server as llm and embedding backends + +Ollama is the default backend for both llm and embedding, so by default you can run lightrag-server with no parameters and the default ones will be used. 
Make sure ollama is installed and is running and default models are already installed on ollama. ```bash -# Custom configuration with specific model and working directory -lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_rag +# Run lightrag with ollama, mistral-nemo:latest for llm, and bge-m3:latest for embedding +lightrag-server -# Using specific models (ensure they are installed in your LoLLMs instance) -lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 +# Using specific models (ensure they are installed in your ollama instance) +lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024 -# Using specific models and an authentication key -lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey +# Using an authentication key +lightrag-server --key my-key +# Using lollms for llm and ollama for embedding +lightrag-server --llm-binding lollms ``` -#### Ollama RAG Server +#### Running a Lightrag server with lollms default local server as llm and embedding backends ```bash -# Custom configuration with specific model and working directory -ollama-lightrag-server --model mistral-nemo:latest --port 8080 --working-dir ./custom_rag +# Run lightrag with lollms, mistral-nemo:latest for llm, and bge-m3:latest for embedding, use lollms for both llm and embedding +lightrag-server --llm-binding lollms --embedding-binding lollms -# Using specific models (ensure they are installed in your Ollama instance) -ollama-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 +# Using specific models (ensure they are installed in your ollama instance) +lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024 + +# Using an authentication key +lightrag-server --key my-key + +# Using lollms for llm and openai for embedding +lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small ``` -#### OpenAI RAG Server + +#### Running a Lightrag server with openai server as llm and embedding backends ```bash -# Using GPT-4 with text-embedding-3-large -openai-lightrag-server --port 9624 --model gpt-4 --embedding-model text-embedding-3-large -``` -#### Azure OpenAI RAG Server -```bash -# Using GPT-4 with text-embedding-3-large -azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_rag --embedding-model text-embedding-3-large +# Run lightrag with lollms, GPT-4o-mini for llm, and text-embedding-3-small for embedding, use openai for both llm and embedding +lightrag-server --llm-binding openai --llm-model GPT-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small + +# Using an authentication key +lightrag-server --llm-binding openai --llm-model GPT-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key + +# Using lollms for llm and openai for embedding +lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small ``` +#### Running a Lightrag server with azure openai server as llm and embedding backends + +```bash +# Run lightrag with lollms, GPT-4o-mini for llm, and text-embedding-3-small for embedding, use openai for both llm and embedding +lightrag-server --llm-binding azure_openai --llm-model GPT-4o-mini 
--embedding-binding openai --embedding-model text-embedding-3-small + +# Using an authentication key +lightrag-server --llm-binding azure_openai --llm-model GPT-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key + +# Using lollms for llm and azure_openai for embedding +lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small +``` **Important Notes:** - For LoLLMs: Make sure the specified models are installed in your LoLLMs instance @@ -1069,10 +1051,7 @@ azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_r For help on any server, use the --help flag: ```bash -lollms-lightrag-server --help -ollama-lightrag-server --help -openai-lightrag-server --help -azure-openai-lightrag-server --help +lightrag-server --help ``` Note: If you don't need the API functionality, you can install the base package without API support using: @@ -1092,7 +1071,7 @@ Query the RAG system with options for different search modes. ```bash curl -X POST "http://localhost:9621/query" \ -H "Content-Type: application/json" \ - -d '{"query": "Your question here", "mode": "hybrid"}' + -d '{"query": "Your question here", "mode": "hybrid", ""}' ``` #### POST /query/stream diff --git a/lightrag/api/azure_openai_lightrag_server.py b/lightrag/api/azure_openai_lightrag_server.py deleted file mode 100644 index abe3f738..00000000 --- a/lightrag/api/azure_openai_lightrag_server.py +++ /dev/null @@ -1,532 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import asyncio -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import ( - azure_openai_complete_if_cache, - azure_openai_embedding, -) -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import os -from dotenv import load_dotenv -import inspect -import json -from fastapi.responses import StreamingResponse - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - -load_dotenv() - -AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") -AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") -AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") -AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") - -AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") -AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with OpenAI integration" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", default="gpt-4o", help="OpenAI model name (default: gpt-4o)" - ) - parser.add_argument( - "--embedding-model", - 
default="text-embedding-3-large", - help="OpenAI embedding model (default: text-embedding-3-large)", - ) - - # RAG configuration - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - parser.add_argument( - "--enable-cache", - default=True, - help="Enable response cache (default: True)", - ) - # Logging configuration - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - only_need_context: bool = False - # stream: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -async def get_embedding_dim(embedding_model: str) -> int: - """Get embedding dimensions for the specified model""" - test_text = ["This is a test sentence."] - embedding = await azure_openai_embedding(test_text, model=embedding_model) - return embedding.shape[1] - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - 
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # Get embedding dimensions - embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model)) - - async def async_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs - ): - """Async wrapper for OpenAI completion""" - kwargs.pop("keyword_extraction", None) - - return await azure_openai_complete_if_cache( - args.model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - base_url=AZURE_OPENAI_ENDPOINT, - api_key=AZURE_OPENAI_API_KEY, - api_version=AZURE_OPENAI_API_VERSION, - **kwargs, - ) - - # Initialize RAG with OpenAI configuration - rag = LightRAG( - enable_llm_cache=args.enable_cache, - working_dir=args.working_dir, - llm_model_func=async_openai_complete, - llm_model_name=args.model, - llm_model_max_token_size=args.max_tokens, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: azure_openai_embedding( - texts, model=args.embedding_model - ), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/resetcache", dependencies=[Depends(optional_api_key)]) - async def reset_cache(): - """Manually reset cache""" - try: - cachefile = args.working_dir + "/kv_store_llm_response_cache.json" - if os.path.exists(cachefile): - with open(cachefile, "w") as f: - f.write("{}") - 
return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=False, - only_need_context=request.only_need_context, - ), - ) - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - if inspect.isasyncgen(response): - - async def stream_generator(): - async for chunk in response: - yield json.dumps({"data": chunk}) + "\n" - - return StreamingResponse( - stream_generator(), media_type="application/json" - ) - else: - return QueryResponse(response=response) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=1, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. 
Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "embedding_dim": embedding_dim, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lightrag_server.py similarity index 82% rename from lightrag/api/lollms_lightrag_server.py rename to lightrag/api/lightrag_server.py index 8a2804a0..4f8e38cd 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -4,6 +4,10 @@ import logging import argparse from lightrag import LightRAG, QueryParam from lightrag.llm import lollms_model_complete, lollms_embed +from lightrag.llm import ollama_model_complete, ollama_embed +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding + from lightrag.utils import EmbeddingFunc from typing import Optional, List from enum import Enum @@ -19,12 +23,36 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": "http://localhost:11434", + "lollms": "http://localhost:9600", + "azure_openai": "https://api.openai.com/v1", + "openai": "https://api.openai.com/v1" + } + return default_hosts.get(binding_type, 
"http://localhost:11434") # fallback to ollama if unknown def parse_args(): parser = argparse.ArgumentParser( description="LightRAG FastAPI Server with separate working and input directories" ) + #Start by the bindings + parser.add_argument( + "--llm-binding", + default="ollama", + help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + parser.add_argument( + "--embedding-binding", + default="ollama", + help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + + # Parse just these arguments first + temp_args, _ = parser.parse_known_args() + + # Add remaining arguments with dynamic defaults for hosts # Server configuration parser.add_argument( "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" @@ -45,22 +73,33 @@ def parse_args(): help="Directory containing input documents (default: ./inputs)", ) - # Model configuration + # LLM Model configuration + default_llm_host = get_default_host(temp_args.llm_binding) parser.add_argument( - "--model", + "--llm-binding-host", + default=default_llm_host, + help=f"llm server host URL (default: {default_llm_host})", + ) + + parser.add_argument( + "--llm-model", default="mistral-nemo:latest", help="LLM model name (default: mistral-nemo:latest)", ) + + # Embedding model configuration + default_embedding_host = get_default_host(temp_args.embedding_binding) + parser.add_argument( + "--embedding-binding-host", + default=default_embedding_host, + help=f"embedding server host URL (default: {default_embedding_host})", + ) + parser.add_argument( "--embedding-model", default="bge-m3:latest", help="Embedding model name (default: bge-m3:latest)", ) - parser.add_argument( - "--lollms-host", - default="http://localhost:9600", - help="lollms host URL (default: http://localhost:9600)", - ) # RAG configuration parser.add_argument( @@ -188,6 +227,15 @@ def get_api_key_dependency(api_key: Optional[str]): def create_app(args): + # Verify that bindings arer correctly setup + if args.llm_binding not in ["lollms", "ollama", "openai"]: + raise Exception("llm binding not supported") + + if args.embedding_binding not in ["lollms", "ollama", "openai"]: + raise Exception("embedding binding not supported") + + + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -203,7 +251,7 @@ def create_app(args): + "(With authentication)" if api_key else "", - version="1.0.0", + version="1.0.1", openapi_tags=[{"name": "api"}], ) @@ -225,23 +273,32 @@ def create_app(args): # Initialize document manager doc_manager = DocumentManager(args.input_dir) + + # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete, - llm_model_name=args.model, + llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, + llm_model_name=args.llm_model, llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ - "host": args.lollms_host, + "host": args.llm_binding_host, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( embedding_dim=args.embedding_dim, max_token_size=args.max_embed_tokens, func=lambda texts: lollms_embed( - texts, embed_model=args.embedding_model, host=args.lollms_host - ), + texts, embed_model=args.embedding_model, host=args.embedding_binding_host + ) if args.llm_binding=="lollms" 
else ollama_embed( + texts, embed_model=args.embedding_model, host=args.embedding_binding_host + ) if args.llm_binding=="ollama" else azure_openai_embedding( + texts, model=args.embedding_model # no host is used for openai + ) if args.llm_binding=="azure_openai" else openai_embedding( + texts, model=args.embedding_model # no host is used for openai + ) + ), ) @@ -470,10 +527,17 @@ def create_app(args): "input_directory": str(args.input_dir), "indexed_files": len(doc_manager.indexed_files), "configuration": { - "model": args.model, + # LLM configuration binding/host address (if applicable)/model (if applicable) + "llm_binding": args.llm_binding, + "llm_binding_host": args.llm_binding_host, + "llm_model": args.llm_model, + + # embedding model configuration binding/host address (if applicable)/model (if applicable) + "embedding_binding": args.embedding_binding, + "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, + "max_tokens": args.max_tokens, - "lollms_host": args.lollms_host, }, } diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py deleted file mode 100644 index b3140aba..00000000 --- a/lightrag/api/ollama_lightrag_server.py +++ /dev/null @@ -1,491 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import ollama_model_complete, ollama_embed -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import os - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", - default="mistral-nemo:latest", - help="LLM model name (default: mistral-nemo:latest)", - ) - parser.add_argument( - "--embedding-model", - default="bge-m3:latest", - help="Embedding model name (default: bge-m3:latest)", - ) - parser.add_argument( - "--ollama-host", - default="http://localhost:11434", - help="Ollama host URL (default: http://localhost:11434)", - ) - - # RAG configuration - parser.add_argument( - "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" - ) - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--embedding-dim", - type=int, - default=1024, - help="Embedding dimensions (default: 1024)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - - # Logging configuration - 
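The nested conditional above that picks `llm_model_func` and the embedding callable from `--llm-binding` is compact but hard to read; the same dispatch can be sketched as a lookup table (purely illustrative, reusing the helpers imported at the top of `lightrag_server.py` and the parsed `args`; note that, as written in the patch, the embedding branch also tests `args.llm_binding` rather than `args.embedding_binding`):

```
# Illustrative alternative to the chained conditional expressions; not the code in the patch.
LLM_FUNCS = {
    "lollms": lollms_model_complete,
    "ollama": ollama_model_complete,
    "azure_openai": azure_openai_complete_if_cache,
    "openai": openai_complete_if_cache,
}
EMBED_FUNCS = {
    "lollms": lambda texts: lollms_embed(
        texts, embed_model=args.embedding_model, host=args.embedding_binding_host
    ),
    "ollama": lambda texts: ollama_embed(
        texts, embed_model=args.embedding_model, host=args.embedding_binding_host
    ),
    "azure_openai": lambda texts: azure_openai_embedding(texts, model=args.embedding_model),
    "openai": lambda texts: openai_embedding(texts, model=args.embedding_model),
}

llm_model_func = LLM_FUNCS[args.llm_binding]
embed_callable = EMBED_FUNCS[args.llm_binding]  # args.embedding_binding may be what was intended
```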
parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # 
Initialize RAG - rag = LightRAG( - working_dir=args.working_dir, - llm_model_func=ollama_model_complete, - llm_model_name=args.model, - llm_model_max_async=args.max_async, - llm_model_max_token_size=args.max_tokens, - llm_model_kwargs={ - "host": args.ollama_host, - "options": {"num_ctx": args.max_tokens}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=args.embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: ollama_embed( - texts, embed_model=args.embedding_model, host=args.ollama_host - ), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. 
Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - ), - ) - - if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - else: - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = rag.query( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - - async def stream_generator(): - async for chunk in response: - yield chunk - - return stream_generator() - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=len(rag), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". 
Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "ollama_host": args.ollama_host, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py deleted file mode 100644 index 349c09da..00000000 --- a/lightrag/api/openai_lightrag_server.py +++ /dev/null @@ -1,506 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import asyncio -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import openai_complete_if_cache, openai_embedding -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import nest_asyncio - -import os - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - -# Apply nest_asyncio to solve event loop issues -nest_asyncio.apply() - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with OpenAI integration" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", default="gpt-4", help="OpenAI model name (default: gpt-4)" - ) - parser.add_argument( - "--embedding-model", - default="text-embedding-3-large", - help="OpenAI embedding model (default: text-embedding-3-large)", - ) - - # RAG configuration - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - - # Logging 
configuration - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -async def get_embedding_dim(embedding_model: str) -> int: - """Get embedding dimensions for the specified model""" - test_text = ["This is a test sentence."] - embedding = await openai_embedding(test_text, model=embedding_model) - return embedding.shape[1] - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional 
API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # Get embedding dimensions - embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model)) - - async def async_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs - ): - """Async wrapper for OpenAI completion""" - return await openai_complete_if_cache( - args.model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - - # Initialize RAG with OpenAI configuration - rag = LightRAG( - working_dir=args.working_dir, - llm_model_func=async_openai_complete, - llm_model_name=args.model, - llm_model_max_token_size=args.max_tokens, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: openai_embedding(texts, model=args.embedding_model), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - rag.insert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. 
Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - rag.insert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - ), - ) - - if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - else: - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = rag.query( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - - async def stream_generator(): - async for chunk in response: - yield chunk - - return stream_generator() - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - rag.insert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=len(rag), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". 
Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "embedding_dim": embedding_dim, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/setup.py b/setup.py index 368610f6..38eff646 100644 --- a/setup.py +++ b/setup.py @@ -100,10 +100,7 @@ setuptools.setup( }, entry_points={ "console_scripts": [ - "lollms-lightrag-server=lightrag.api.lollms_lightrag_server:main [api]", - "ollama-lightrag-server=lightrag.api.ollama_lightrag_server:main [api]", - "openai-lightrag-server=lightrag.api.openai_lightrag_server:main [api]", - "azure-openai-lightrag-server=lightrag.api.azure_openai_lightrag_server:main [api]", + "lightrag-server=lightrag.api.lightrag_server:main [api]", ], }, ) From adb288c5bb7ac6753daf898cb99c1458f9663773 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 21:39:25 +0100 Subject: [PATCH 17/38] added timeout --- lightrag/api/lightrag_server.py | 23 +++++++++++++++++++++++ lightrag/llm.py | 4 +++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 4f8e38cd..1175afab 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -101,6 +101,12 @@ def parse_args(): help="Embedding model name (default: bge-m3:latest)", ) + parser.add_argument( + "--timeout", + default=300, + help="Timeout is seconds (useful when using slow AI)", + ) + # RAG configuration parser.add_argument( "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" @@ -139,6 +145,22 @@ def parse_args(): default=None, ) + # Optional https parameters + parser.add_argument( + "--ssl", + action="store_true", + help="Enable HTTPS (default: False)" + ) + parser.add_argument( + "--ssl-certfile", + default=None, + help="Path to SSL certificate file (required if --ssl is enabled)" + ) + parser.add_argument( + "--ssl-keyfile", + default=None, + help="Path to SSL private key file (required if --ssl is enabled)" + ) return parser.parse_args() @@ -284,6 +306,7 @@ def create_app(args): llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, + "timeout":args.timeout "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( diff --git a/lightrag/llm.py b/lightrag/llm.py index 0c17019a..4e01dd51 100644 --- a/lightrag/llm.py +++ 
b/lightrag/llm.py @@ -336,6 +336,7 @@ async def hf_model_if_cache( (RateLimitError, APIConnectionError, APITimeoutError) ), ) + async def ollama_model_if_cache( model, prompt, @@ -406,8 +407,9 @@ async def lollms_model_if_cache( full_prompt += prompt request_data["prompt"] = full_prompt + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", 300)) # 300 seconds = 5 minutes - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=timeout) as session: if stream: async def inner(): From ab3cc3f0f47790ea3a713b2a3831aac1efb3c854 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 21:39:41 +0100 Subject: [PATCH 18/38] fixed missing coma --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1175afab..d4cddd6c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -306,7 +306,7 @@ def create_app(args): llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, - "timeout":args.timeout + "timeout":args.timeout, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( From a619b010640a356d95e241ff07e17717ee4c2fe1 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 22:17:13 +0100 Subject: [PATCH 19/38] Next test of timeout --- lightrag/api/lightrag_server.py | 11 ++++++++--- lightrag/llm.py | 3 +-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d4cddd6c..40b63463 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -101,12 +101,17 @@ def parse_args(): help="Embedding model name (default: bge-m3:latest)", ) + def timeout_type(value): + if value is None or value == "None": + return None + return int(value) + parser.add_argument( "--timeout", - default=300, - help="Timeout is seconds (useful when using slow AI)", + default=None, + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", ) - # RAG configuration parser.add_argument( "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" diff --git a/lightrag/llm.py b/lightrag/llm.py index 4e01dd51..7a51d025 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -407,11 +407,10 @@ async def lollms_model_if_cache( full_prompt += prompt request_data["prompt"] = full_prompt - timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", 300)) # 300 seconds = 5 minutes + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None)) async with aiohttp.ClientSession(timeout=timeout) as session: if stream: - async def inner(): async with session.post( f"{base_url}/lollms_generate", json=request_data From e21fbef60b702e3205dbdc3d92c680bc7d2b90c5 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 22:38:57 +0100 Subject: [PATCH 20/38] updated documlentation --- README.md | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 278f6a72..57aee435 100644 --- a/README.md +++ b/README.md @@ -959,23 +959,26 @@ Each server has its own specific configuration options: | Parameter | Default | Description | |-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --llm-binding | ollama | LLM binding to be used. 
Supported: lollms, ollama, openai (default: ollama) | -| --llm-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | llm server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | -| --model | mistral-nemo:latest | LLM model name | -| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama) | -| --embedding-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | embedding server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | +| --host | 0.0.0.0 | Server host | +| --port | 9621 | Server port | +| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai | +| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) | +| --llm-model | mistral-nemo:latest | LLM model name | +| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai | +| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) | | --embedding-model | bge-m3:latest | Embedding model name | -| --embedding-binding-host | http://localhost:9600 | LoLLMS backend URL | -| --working-dir | ./rag_storage | Working directory for RAG | +| --working-dir | ./rag_storage | Working directory for RAG storage | +| --input-dir | ./inputs | Directory containing input documents | | --max-async | 4 | Maximum async operations | | --max-tokens | 32768 | Maximum token size | | --embedding-dim | 1024 | Embedding dimensions | | --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-file | ./book.txt | Initial input file | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | +| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout | +| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | +| --key | None | API key for authentication. 
Protects lightrag server against unauthorized access | +| --ssl | False | Enable HTTPS | +| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) | +| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | From e0e656ab014138c129aab9f48e7f3f8bcf6b57b7 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 11 Jan 2025 01:35:49 +0100 Subject: [PATCH 21/38] Added ssl support --- lightrag/api/lightrag_server.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 40b63463..1f88e776 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -262,7 +262,16 @@ def create_app(args): raise Exception("embedding binding not supported") - + # Add SSL validation + if args.ssl: + if not args.ssl_certfile or not args.ssl_keyfile: + raise Exception("SSL certificate and key files must be provided when SSL is enabled") + if not os.path.exists(args.ssl_certfile): + raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") + if not os.path.exists(args.ssl_keyfile): + raise Exception(f"SSL key file not found: {args.ssl_keyfile}") + + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -577,7 +586,17 @@ def main(): import uvicorn app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) + uvicorn_config = { + "app": app, + "host": args.host, + "port": args.port, + } + if args.ssl: + uvicorn_config.update({ + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + }) + uvicorn.run(**uvicorn_config) if __name__ == "__main__": From 224fce9b1b1a887a998d2ee818f0855c950422de Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 11 Jan 2025 01:37:07 +0100 Subject: [PATCH 22/38] run precommit to fix linting issues --- lightrag/api/lightrag_server.py | 83 ++++++++++++++++++++------------- lightrag/llm.py | 2 +- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1f88e776..644e622d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN + def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": "http://localhost:11434", "lollms": "http://localhost:9600", "azure_openai": "https://api.openai.com/v1", - "openai": "https://api.openai.com/v1" + "openai": "https://api.openai.com/v1", } - return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown + return default_hosts.get( + binding_type, "http://localhost:11434" + ) # fallback to ollama if unknown + def parse_args(): parser = argparse.ArgumentParser( description="LightRAG FastAPI Server with separate working and input directories" ) - #Start by the bindings + # Start by the bindings parser.add_argument( "--llm-binding", default="ollama", @@ -48,7 +52,7 @@ def parse_args(): default="ollama", help="Embedding binding to be used. 
Supported: lollms, ollama, openai (default: ollama)", ) - + # Parse just these arguments first temp_args, _ = parser.parse_known_args() @@ -152,19 +156,17 @@ def parse_args(): # Optional https parameters parser.add_argument( - "--ssl", - action="store_true", - help="Enable HTTPS (default: False)" + "--ssl", action="store_true", help="Enable HTTPS (default: False)" ) parser.add_argument( "--ssl-certfile", default=None, - help="Path to SSL certificate file (required if --ssl is enabled)" + help="Path to SSL certificate file (required if --ssl is enabled)", ) parser.add_argument( "--ssl-keyfile", - default=None, - help="Path to SSL private key file (required if --ssl is enabled)" + default=None, + help="Path to SSL private key file (required if --ssl is enabled)", ) return parser.parse_args() @@ -261,17 +263,17 @@ def create_app(args): if args.embedding_binding not in ["lollms", "ollama", "openai"]: raise Exception("embedding binding not supported") - # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception("SSL certificate and key files must be provided when SSL is enabled") + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - - + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -309,33 +311,48 @@ def create_app(args): # Initialize document manager doc_manager = DocumentManager(args.input_dir) - - # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, + llm_model_func=lollms_model_complete + if args.llm_binding == "lollms" + else ollama_model_complete + if args.llm_binding == "ollama" + else azure_openai_complete_if_cache + if args.llm_binding == "azure_openai" + else openai_complete_if_cache, llm_model_name=args.llm_model, llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, - "timeout":args.timeout, + "timeout": args.timeout, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( embedding_dim=args.embedding_dim, max_token_size=args.max_embed_tokens, func=lambda texts: lollms_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="lollms" else ollama_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="ollama" else azure_openai_embedding( - texts, model=args.embedding_model # no host is used for openai - ) if args.llm_binding=="azure_openai" else openai_embedding( - texts, model=args.embedding_model # no host is used for openai + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, ) - + if args.llm_binding == "lollms" + else ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.llm_binding == "ollama" + else azure_openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ) + if args.llm_binding == "azure_openai" + else openai_embedding( + texts, + model=args.embedding_model, # no 
host is used for openai + ), ), ) @@ -568,12 +585,10 @@ def create_app(args): "llm_binding": args.llm_binding, "llm_binding_host": args.llm_binding_host, "llm_model": args.llm_model, - # embedding model configuration binding/host address (if applicable)/model (if applicable) "embedding_binding": args.embedding_binding, "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, }, } @@ -590,12 +605,14 @@ def main(): "app": app, "host": args.host, "port": args.port, - } + } if args.ssl: - uvicorn_config.update({ - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile, - }) + uvicorn_config.update( + { + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + } + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm.py b/lightrag/llm.py index 7a51d025..c49ed138 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -336,7 +336,6 @@ async def hf_model_if_cache( (RateLimitError, APIConnectionError, APITimeoutError) ), ) - async def ollama_model_if_cache( model, prompt, @@ -411,6 +410,7 @@ async def lollms_model_if_cache( async with aiohttp.ClientSession(timeout=timeout) as session: if stream: + async def inner(): async with session.post( f"{base_url}/lollms_generate", json=request_data From d03d6f5fc54a85ad460ee3f468351aa55077cc9b Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Sat, 11 Jan 2025 09:30:19 +0800 Subject: [PATCH 23/38] Revised the postgres implementation, to use attributes(node_id) rather than nodes to identify an entity. Which significantly reduced the table counts. --- README.md | 5 + lightrag/kg/postgres_impl.py | 268 +++++++++++++++-------------------- 2 files changed, 118 insertions(+), 155 deletions(-) diff --git a/README.md b/README.md index ea8d0a97..d6d22522 100644 --- a/README.md +++ b/README.md @@ -361,6 +361,11 @@ see test_neo4j.py for a working example. For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE). * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac. * How to start? 
Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) +* Create index for AGE example: (Change below `dickens` to your graph name if necessary) + ``` + SET search_path = ag_catalog, "$user", public; + CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"')); + ``` ### Insert Custom KG diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 033d63d6..ccbff679 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -130,6 +130,7 @@ class PostgreSQLDB: data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None, + upsert: bool = False, ): try: async with self.pool.acquire() as connection: @@ -140,6 +141,11 @@ class PostgreSQLDB: await connection.execute(sql) else: await connection.execute(sql, *data.values()) + except asyncpg.exceptions.UniqueViolationError as e: + if upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") except Exception as e: logger.error(f"PostgreSQL database error: {e}") print(sql) @@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage): if dtype == "vertex": vertex = json.loads(v) - field = json.loads(v).get("properties") + field = vertex.get("properties") if not field: field = {} - field["label"] = PGGraphStorage._decode_graph_label(vertex["label"]) + field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) d[k] = field # convert edge from id-label->id by replacing id with node information # we only do this if the vertex was also returned in the query @@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage): # otherwise return the value stripping out some common special chars return field.replace("(", "_").replace(")", "") - @staticmethod - def _wrap_query(query: str, graph_name: str, **params: str) -> str: - """ - Convert a cypher query to an Apache Age compatible - sql query by wrapping the cypher query in ag_catalog.cypher, - casting results to agtype and building a select statement - - Args: - query (str): a valid cypher query - graph_name (str): the name of the graph to query - params (dict): parameters for the query - - Returns: - str: an equivalent pgsql query - """ - - # pgsql template - template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ - {query} - $$) AS ({fields})""" - - # if there are any returned fields they must be added to the pgsql query - if "return" in query.lower(): - # parse return statement to identify returned fields - fields = ( - query.lower() - .split("return")[-1] - .split("distinct")[-1] - .split("order by")[0] - .split("skip")[0] - .split("limit")[0] - .split(",") - ) - - # raise exception if RETURN * is found as we can't resolve the fields - if "*" in [x.strip() for x in fields]: - raise ValueError( - "AGE graph does not support 'RETURN *'" - + " statements in Cypher queries" - ) - - # get pgsql formatted field names - fields = [ - PGGraphStorage._get_col_name(field, idx) - for idx, field in enumerate(fields) - ] - - # build resulting pgsql relation - fields_str = ", ".join( - [field.split(".")[-1] + " agtype" for field in fields] - ) - - # if no return statement we still need to return a single field of type agtype - else: - fields_str = "a agtype" - - select_str = "*" - - return template.format( - graph_name=graph_name, - query=query.format(**params), - fields=fields_str, - projection=select_str, - ) - async def _query( - self, query: str, 
readonly=True, upsert_edge=False, **params: str + self, query: str, readonly: bool = True, upsert: bool = False ) -> List[Dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an @@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage): List[Dict[str, Any]]: a list of dictionaries containing the result set """ # convert cypher query to pgsql/age query - wrapped_query = self._wrap_query(query, self.graph_name, **params) + wrapped_query = query # execute the query, rolling back on an error try: @@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage): graph_name=self.graph_name, ) else: - # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING) - # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future. - if upsert_edge: - data = await self.db.execute( - f"{wrapped_query};{wrapped_query};", - for_age=True, - graph_name=self.graph_name, - ) - else: - data = await self.db.execute( - wrapped_query, for_age=True, graph_name=self.graph_name - ) + data = await self.db.execute( + wrapped_query, + for_age=True, + graph_name=self.graph_name, + upsert=upsert, + ) except Exception as e: raise PGGraphQueryException( { - "message": f"Error executing graph query: {query.format(**params)}", + "message": f"Error executing graph query: {query}", "wrapped": wrapped_query, "detail": str(e), } @@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - single_result = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN count(n) > 0 AS node_exists + $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) + + single_result = (await self._query(query))[0] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, single_result["node_exists"], ) return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) - RETURN COUNT(r) > 0 AS edge_exists""" - params = { - "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), - "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), - } - single_result = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) + RETURN COUNT(r) > 0 AS edge_exists + $$) AS (edge_exists bool)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + + single_result = (await self._query(query))[0] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, single_result["edge_exists"], ) return single_result["edge_exists"] async def get_node(self, node_id: str) -> Union[dict, None]: - entity_name_label = node_id.strip('"') - 
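Each storage method above now wraps its Cypher in a plain SQL statement via `cypher('<graph>', $$ ... $$)`; a rough standalone sketch of issuing one such node-existence check with `asyncpg` (connection string, graph name and node id are placeholders, and the real implementation additionally escapes the id through `_encode_graph_label`):

```
import asyncio
import asyncpg

async def node_exists(dsn: str, graph_name: str, node_id: str) -> bool:
    # Assumes an AGE-enabled PostgreSQL with the graph already created.
    conn = await asyncpg.connect(dsn)
    try:
        # Depending on server configuration, AGE may need to be loaded per session.
        await conn.execute("LOAD 'age';")
        await conn.execute('SET search_path = ag_catalog, "$user", public;')
        query = (
            f"SELECT * FROM cypher('{graph_name}', $$ "
            f'MATCH (n:Entity {{node_id: "{node_id}"}}) '
            f"RETURN count(n) > 0 AS node_exists "
            f"$$) AS (node_exists bool)"
        )
        row = await conn.fetchrow(query)
        return bool(row["node_exists"])
    finally:
        await conn.close()

# asyncio.run(node_exists("postgresql://rag:rag@localhost:5432/rag", "dickens", "SCROOGE"))
```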
query = """MATCH (n:`{label}`) RETURN n""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - record = await self._query(query, **params) + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN n + $$) AS (n agtype)""" % (self.graph_name, label) + record = await self._query(query) if record: node = record[0] node_dict = node["n"] logger.debug( "{%s}: query: {%s}, result: {%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, node_dict, ) return node_dict return None async def node_degree(self, node_id: str) -> int: - entity_name_label = node_id.strip('"') + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - record = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"})-[]->(x) + RETURN count(x) AS total_edge_count + $$) AS (total_edge_count integer)""" % (self.graph_name, label) + record = (await self._query(query))[0] if record: edge_count = int(record["total_edge_count"]) logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, edge_count, ) return edge_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') - src_degree = await self.node_degree(entity_name_label_source) - trg_degree = await self.node_degree(entity_name_label_target) + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree @@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage): Returns: list: List of all relationships/edges found """ - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) - RETURN properties(r) as edge_properties - LIMIT 1""" - params = { - "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), - "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), - } - record = await self._query(query, **params) + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) + RETURN properties(r) as edge_properties + LIMIT 1 + $$) AS (edge_properties agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + record = await self._query(query) if record and record[0] and record[0]["edge_properties"]: result = record[0]["edge_properties"] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, result, ) return result @@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage): Retrieves all edges (relationships) for a particular node identified by its label. 
:return: List of dictionaries containing edge information """ - node_label = source_node_id.strip('"') + label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - query = """MATCH (n:`{label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - params = {"label": PGGraphStorage._encode_graph_label(node_label)} - results = await self._query(query, **params) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + OPTIONAL MATCH (n)-[r]-(connected) + RETURN n, r, connected + $$) AS (n agtype, r agtype, connected agtype)""" % ( + self.graph_name, + label, + ) + + results = await self._query(query) edges = [] for record in results: source_node = record["n"] if record["n"] else None connected_node = record["connected"] if record["connected"] else None source_label = ( - source_node["label"] if source_node and source_node["label"] else None + source_node["node_id"] + if source_node and source_node["node_id"] + else None ) target_label = ( - connected_node["label"] - if connected_node and connected_node["label"] + connected_node["node_id"] + if connected_node and connected_node["node_id"] else None ) @@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = node_id.strip('"') + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) properties = node_data - query = """MERGE (n:`{label}`) - SET n += {properties}""" - params = { - "label": PGGraphStorage._encode_graph_label(label), - "properties": PGGraphStorage._format_properties(properties), - } + query = """SELECT * FROM cypher('%s', $$ + MERGE (n:Entity {node_id: "%s"}) + SET n += %s + RETURN n + $$) AS (n agtype)""" % ( + self.graph_name, + label, + PGGraphStorage._format_properties(properties), + ) + try: - await self._query(query, readonly=False, **params) + await self._query(query, readonly=False, upsert=True) logger.debug( "Upserted node with label '{%s}' and properties: {%s}", label, @@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) edge_properties = edge_data - query = """MATCH (source:`{src_label}`) - WITH source - MATCH (target:`{tgt_label}`) - MERGE (source)-[r:DIRECTED]->(target) - SET r += {properties} - RETURN r""" - params = { - "src_label": PGGraphStorage._encode_graph_label(source_node_label), - "tgt_label": PGGraphStorage._encode_graph_label(target_node_label), - "properties": PGGraphStorage._format_properties(edge_properties), - } + query = """SELECT * FROM cypher('%s', $$ + MATCH (source:Entity {node_id: "%s"}) + WITH source + MATCH (target:Entity {node_id: "%s"}) + MERGE (source)-[r:DIRECTED]->(target) + SET r += %s + RETURN r + $$) AS (r agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + PGGraphStorage._format_properties(edge_properties), + ) # logger.info(f"-- inserting edge after formatted: {params}") try: - await self._query(query, readonly=False, upsert_edge=True, **params) + await self._query(query, readonly=False, upsert=True) logger.debug( "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - source_node_label, - target_node_label, + 
src_label, + tgt_label, edge_properties, ) except Exception as e: From d03192a3bdfe1411e31a8961754b11e8b96415bd Mon Sep 17 00:00:00 2001 From: iridium-soda Date: Sat, 11 Jan 2025 09:27:53 +0000 Subject: [PATCH 24/38] fix: Resolve 500 error caused by missing `len()` for `LightRAG` --- lightrag/api/lollms_lightrag_server.py | 2 +- lightrag/api/ollama_lightrag_server.py | 2 +- lightrag/api/openai_lightrag_server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lollms_lightrag_server.py index 8a2804a0..50a47ec1 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lollms_lightrag_server.py @@ -376,7 +376,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=len(rag), + document_count=1, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py index b3140aba..66b272d8 100644 --- a/lightrag/api/ollama_lightrag_server.py +++ b/lightrag/api/ollama_lightrag_server.py @@ -375,7 +375,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=len(rag), + document_count=1, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index 349c09da..d65eaa34 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -390,7 +390,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=len(rag), + document_count=1, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From 7a56f2924629ca99cde1fe952738dfe0a701e47d Mon Sep 17 00:00:00 2001 From: iridium-soda Date: Sat, 11 Jan 2025 09:38:54 +0000 Subject: [PATCH 25/38] fix --- lightrag/api/openai_lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index d65eaa34..349c09da 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -390,7 +390,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=1, + document_count=len(rag), ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From fd5683f6ad48189395ac21b936ef9085a4cf077d Mon Sep 17 00:00:00 2001 From: iridium-soda Date: Sat, 11 Jan 2025 09:39:52 +0000 Subject: [PATCH 26/38] Revert "fix" This reverts commit 7a56f2924629ca99cde1fe952738dfe0a701e47d. --- lightrag/api/openai_lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index 349c09da..d65eaa34 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -390,7 +390,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=len(rag), + document_count=1, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From 63a71c04fd6dcab65492a9d40c5af80eb0a3517b Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Sun, 12 Jan 2025 16:56:30 +0800 Subject: [PATCH 27/38] Add known issue of Apache AGE to the readme. 
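
Editor's note: the storage rewrites in the hunks above drop the per-entity vertex labels and the removed `_wrap_query` helper in favour of issuing ready-made Apache AGE statements: each Cypher fragment is embedded in a `cypher('<graph>', $$ ... $$)` call, the result columns are declared in the trailing `AS (...)` list, and all vertices share one `Entity` label keyed by a `node_id` property. A minimal standalone sketch of that pattern, not part of the patch and assuming only an `asyncpg` pool and an existing AGE graph:

```python
import asyncpg

async def node_exists(pool: asyncpg.Pool, graph_name: str, node_id: str) -> bool:
    # AGE exposes Cypher through the cypher() SQL function; the result columns
    # must be declared explicitly in the trailing AS (...) clause.
    query = """SELECT * FROM ag_catalog.cypher('%s', $$
                 MATCH (n:Entity {node_id: "%s"})
                 RETURN count(n) > 0 AS node_exists
               $$) AS (node_exists bool)""" % (graph_name, node_id)
    async with pool.acquire() as conn:
        # search_path must include ag_catalog so the agtype machinery resolves.
        await conn.execute('SET search_path = ag_catalog, "$user", public')
        row = await conn.fetchrow(query)
        return bool(row["node_exists"])
```
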
--- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 6bc2fb3f..e874bb0a 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,11 @@ For production level scenarios you will most likely want to leverage an enterpri SET search_path = ag_catalog, "$user", public; CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"')); ``` +* Known issue of the Apache AGE: The released versions got below issue: + > You might find that the properties of the nodes/edges are empty. + > It is a known issue of the release version: https://github.com/apache/age/pull/1721 + > You can Compile the AGE from source code and fix it. + ### Insert Custom KG From f3e0fb87f509a5b4ee6aaecb4087e89805f594d3 Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Sun, 12 Jan 2025 17:01:31 +0800 Subject: [PATCH 28/38] Add known issue of Apache AGE to the readme. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e874bb0a..237e7687 100644 --- a/README.md +++ b/README.md @@ -369,9 +369,11 @@ For production level scenarios you will most likely want to leverage an enterpri * Known issue of the Apache AGE: The released versions got below issue: > You might find that the properties of the nodes/edges are empty. > It is a known issue of the release version: https://github.com/apache/age/pull/1721 + > > You can Compile the AGE from source code and fix it. + ### Insert Custom KG ```python From 5c679384671d716efc24bcddd5bade5b384f6e9a Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sun, 12 Jan 2025 12:46:23 +0100 Subject: [PATCH 29/38] Resolve 500 error caused by missing len() for LightRAG's API insert_text endpoint --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 644e622d..d29b8b56 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -487,7 +487,7 @@ def create_app(args): return InsertResponse( status="success", message="Text successfully inserted", - document_count=len(rag), + document_count=1, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From 7aaab219eed17b2e8790fe502b58bb129a35606d Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sun, 12 Jan 2025 12:56:08 +0100 Subject: [PATCH 30/38] Fixed awaiting insert --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d29b8b56..5bcb149c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -483,7 +483,7 @@ def create_app(args): ) async def insert_text(request: InsertTextRequest): try: - rag.insert(request.text) + await rag.ainsert(request.text) return InsertResponse( status="success", message="Text successfully inserted", From c01693402173d2ffdb1c6140c9f4aa06a815760b Mon Sep 17 00:00:00 2001 From: Samuel Chan Date: Sun, 12 Jan 2025 21:38:39 +0800 Subject: [PATCH 31/38] Revise the AGE implementation on get_node_edges, to align with Neo4j behavior. 
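
Editor's note: patch 30 above replaces the synchronous `rag.insert(...)` call inside the async `insert_text` endpoint with `await rag.ainsert(...)`; the synchronous wrapper is not awaitable and blocks the handler, so the async variant is the right call inside a FastAPI coroutine. A minimal sketch of the resulting endpoint shape, where the route path and response model are illustrative and `rag` is assumed to be an already configured `LightRAG` instance:

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class InsertTextRequest(BaseModel):
    text: str

@app.post("/documents/text")
async def insert_text(request: InsertTextRequest):
    try:
        await rag.ainsert(request.text)  # rag: a configured LightRAG instance (assumed)
        return {
            "status": "success",
            "message": "Text successfully inserted",
            "document_count": 1,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
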
--- lightrag/kg/postgres_impl.py | 14 +++++++++++--- lightrag/kg/postgres_impl_test.py | 8 ++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ccbff679..b93a345b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -141,13 +141,16 @@ class PostgreSQLDB: await connection.execute(sql) else: await connection.execute(sql, *data.values()) - except asyncpg.exceptions.UniqueViolationError as e: + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + ) as e: if upsert: print("Key value duplicate, but upsert succeeded.") else: logger.error(f"Upsert error: {e}") except Exception as e: - logger.error(f"PostgreSQL database error: {e}") + logger.error(f"PostgreSQL database error: {e.__class__} - {e}") print(sql) print(data) raise @@ -885,7 +888,12 @@ class PGGraphStorage(BaseGraphStorage): ) if source_label and target_label: - edges.append((source_label, target_label)) + edges.append( + ( + PGGraphStorage._decode_graph_label(source_label), + PGGraphStorage._decode_graph_label(target_label), + ) + ) return edges diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py index dc046311..274f03de 100644 --- a/lightrag/kg/postgres_impl_test.py +++ b/lightrag/kg/postgres_impl_test.py @@ -61,7 +61,7 @@ db = PostgreSQLDB( "port": 15432, "user": "rag", "password": "rag", - "database": "rag", + "database": "r1", } ) @@ -74,8 +74,12 @@ async def query_with_age(): embedding_func=None, ) graph.db = db - res = await graph.get_node('"CHRISTMAS-TIME"') + res = await graph.get_node('"A CHRISTMAS CAROL"') print("Node is: ", res) + res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") + print("Edge is: ", res) + res = await graph.get_node_edges('"SCROOGE"') + print("Node Edges are: ", res) async def create_edge_with_age(): From 057e23c4e9dcac9cc21e12647560ae550898ad51 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:13:01 +0800 Subject: [PATCH 32/38] Update __init__.py --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index b8037813..7a26a282 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.0" +__version__ = "1.1.1" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 867475fd1f394dc69da1da47298c7af9ab5682d5 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:28:19 +0800 Subject: [PATCH 33/38] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e8401a3d..90c3ec04 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2025.01.13]🎯📢Our team has launched [MiniRAG](https://github.com/HKUDS/MiniRAG) for small models. - [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). 
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise. From f28b90b2b397400361caa874e0ddc9db31c0aeb1 Mon Sep 17 00:00:00 2001 From: bingo Date: Mon, 13 Jan 2025 07:06:01 +0000 Subject: [PATCH 34/38] 1. add os env NEO4J_MAX_CONNECTION_POOL_SIZE to for neo4j ; 2. fix https://github.com/HKUDS/LightRAG/issues/580 issue for mongoDB document 16MB limit. --- lightrag/kg/mongo_impl.py | 32 +++++++++++++++++++++++++++----- lightrag/kg/neo4j_impl.py | 3 ++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 61222357..5aab9c07 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -2,7 +2,7 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass from pymongo import MongoClient - +from typing import Union from lightrag.utils import logger from lightrag.base import BaseKVStorage @@ -41,11 +41,33 @@ class MongoKVStorage(BaseKVStorage): return set([s for s in data if s not in existing_ids]) async def upsert(self, data: dict[str, dict]): - for k, v in tqdm_async(data.items(), desc="Upserting"): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) - data[k]["_id"] = k + if self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in tqdm_async(items.items(), desc="Upserting"): + key = f"{mode}_{k}" + result = self._data.update_one({"_id": key}, {"$setOnInsert": v}, upsert=True) + if result.upserted_id: + logger.debug(f"\nInserted new document with key: {key}") + data[mode][k]["_id"] = key + else: + for k, v in tqdm_async(data.items(), desc="Upserting"): + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + data[k]["_id"] = k return data - + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + if "llm_response_cache" == self.namespace: + res = {} + v = self._data.find_one({"_id": mode+"_"+id}) + if v: + res[id] = v + print(f"find one by:{id}") + return res + else: + return None + else: + return None + async def drop(self): """ """ pass diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 884fcb40..96247c05 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -39,6 +39,7 @@ class Neo4JStorage(BaseGraphStorage): URI = os.environ["NEO4J_URI"] USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] + MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) DATABASE = os.environ.get( "NEO4J_DATABASE" ) # If this param is None, the home database will be used. If it is not None, the specified database will be used. 
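
Editor's note: the hunk that follows threads the new `NEO4J_MAX_CONNECTION_POOL_SIZE` setting into `GraphDatabase.driver(...)`. A minimal sketch of the same configuration in isolation; note that `os.environ` values arrive as strings, so this sketch casts explicitly, and the driver's own default pool size is 100:

```python
import os
from neo4j import GraphDatabase

URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
# Environment values are strings; cast before handing them to the driver.
MAX_CONNECTION_POOL_SIZE = int(os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800))

driver = GraphDatabase.driver(
    URI,
    auth=(USERNAME, PASSWORD),
    max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
)
```
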
@@ -47,7 +48,7 @@ class Neo4JStorage(BaseGraphStorage): URI, auth=(USERNAME, PASSWORD) ) _database_name = "home database" if DATABASE is None else f"database {DATABASE}" - with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver: + with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD), max_connection_pool_size=MAX_CONNECTION_POOL_SIZE) as _sync_driver: try: with _sync_driver.session(database=DATABASE) as session: try: From 1984da0fd6ee17d3f187a13e423ce13aaac9945f Mon Sep 17 00:00:00 2001 From: bingo Date: Mon, 13 Jan 2025 07:27:30 +0000 Subject: [PATCH 35/38] add logger.debug for mongo_impl get_by_mode_and_id() --- lightrag/kg/mongo_impl.py | 12 +++++++----- lightrag/kg/neo4j_impl.py | 6 +++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 5aab9c07..fbbae8c2 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -45,7 +45,9 @@ class MongoKVStorage(BaseKVStorage): for mode, items in data.items(): for k, v in tqdm_async(items.items(), desc="Upserting"): key = f"{mode}_{k}" - result = self._data.update_one({"_id": key}, {"$setOnInsert": v}, upsert=True) + result = self._data.update_one( + {"_id": key}, {"$setOnInsert": v}, upsert=True + ) if result.upserted_id: logger.debug(f"\nInserted new document with key: {key}") data[mode][k]["_id"] = key @@ -54,20 +56,20 @@ class MongoKVStorage(BaseKVStorage): self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k return data - + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if "llm_response_cache" == self.namespace: res = {} - v = self._data.find_one({"_id": mode+"_"+id}) + v = self._data.find_one({"_id": mode + "_" + id}) if v: res[id] = v - print(f"find one by:{id}") + logger.debug(f"llm_response_cache find one by:{id}") return res else: return None else: return None - + async def drop(self): """ """ pass diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 96247c05..8c2afb5d 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -48,7 +48,11 @@ class Neo4JStorage(BaseGraphStorage): URI, auth=(USERNAME, PASSWORD) ) _database_name = "home database" if DATABASE is None else f"database {DATABASE}" - with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD), max_connection_pool_size=MAX_CONNECTION_POOL_SIZE) as _sync_driver: + with GraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, + ) as _sync_driver: try: with _sync_driver.session(database=DATABASE) as session: try: From c3aba5423f995be628df8dbcb22702d00c9476d9 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Tue, 14 Jan 2025 23:08:39 +0100 Subject: [PATCH 36/38] Added more file types support --- lightrag/api/lightrag_server.py | 300 +++++++++++++++++++++++++++----- lightrag/api/requirements.txt | 1 + 2 files changed, 260 insertions(+), 41 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5bcb149c..d9f7bf06 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -9,7 +9,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding from lightrag.utils import EmbeddingFunc -from typing import Optional, List +from typing import Optional, List, Union from enum import Enum from pathlib import Path import shutil @@ -22,6 +22,7 @@ from fastapi.security import APIKeyHeader 
from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN +import pipmaster as pm def get_default_host(binding_type: str) -> str: @@ -174,7 +175,7 @@ def parse_args(): class DocumentManager: """Handles document operations and tracking""" - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): + def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".pptx")): self.input_dir = Path(input_dir) self.supported_extensions = supported_extensions self.indexed_files = set() @@ -289,7 +290,7 @@ def create_app(args): + "(With authentication)" if api_key else "", - version="1.0.1", + version="1.0.2", openapi_tags=[{"name": "api"}], ) @@ -356,6 +357,85 @@ def create_app(args): ), ) + + + async def index_file(file_path: Union[str, Path]) -> None: + """ Index all files inside the folder with support for multiple file formats + + Args: + file_path: Path to the file to be indexed (str or Path object) + + Raises: + ValueError: If file format is not supported + FileNotFoundError: If file doesn't exist + """ + if not pm.is_installed("aiofiles"): + pm.install("aiofiles") + import aiofiles + + + # Convert to Path object if string + file_path = Path(file_path) + + # Check if file exists + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + content = "" + # Get file extension in lowercase + ext = file_path.suffix.lower() + + match ext: + case ".txt" | ".md": + # Text files handling + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: + content = await f.read() + + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from pypdf2 import PdfReader + # PDF handling + reader = PdfReader(str(file_path)) + content = "" + for page in reader.pages: + content += page.extract_text() + "\n" + + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + + # Word document handling + doc = Document(file_path) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + # PowerPoint handling + prs = Presentation(file_path) + content = "" + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + + case _: + raise ValueError(f"Unsupported file format: {ext}") + + # Insert content into RAG system + if content: + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + logging.info(f"Successfully indexed file: {file_path}") + else: + logging.warning(f"No content extracted from file: {file_path}") + + + + @app.on_event("startup") async def startup_event(): """Index all files in input directory during startup""" @@ -363,13 +443,7 @@ def create_app(args): new_files = doc_manager.scan_directory() for file_path in new_files: try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") + await index_file(file_path) except Exception as e: trace_exception(e) logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -388,11 +462,8 @@ def create_app(args): for file_path in new_files: try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - 
doc_manager.mark_as_indexed(file_path) - indexed_count += 1 + await index_file(file_path) + indexed_count += 1 except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -419,10 +490,7 @@ def create_app(args): shutil.copyfileobj(file.file, buffer) # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) + await index_file(file_path) return { "status": "success", @@ -491,69 +559,219 @@ def create_app(args): ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post( "/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)], ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): + """Insert a file directly into the RAG system + + Args: + file: Uploaded file + description: Optional description of the file + + Returns: + InsertResponse: Status of the insertion operation + + Raises: + HTTPException: For unsupported file types or processing errors + """ try: - content = await file.read() + content = "" + # Get file extension in lowercase + ext = Path(file.filename).suffix.lower() + + match ext: + case ".txt" | ".md": + # Text files handling + text_content = await file.read() + content = text_content.decode("utf-8") + + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from pypdf2 import PdfReader + from io import BytesIO + + # Read PDF from memory + pdf_content = await file.read() + pdf_file = BytesIO(pdf_content) + reader = PdfReader(pdf_file) + content = "" + for page in reader.pages: + content += page.extract_text() + "\n" + + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + # Read DOCX from memory + docx_content = await file.read() + docx_file = BytesIO(docx_content) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + # Read PPTX from memory + pptx_content = await file.read() + pptx_file = BytesIO(pptx_content) + prs = Presentation(pptx_file) + content = "" + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + + case _: + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) + # Insert content into RAG system + if content: + # Add description if provided + if description: + content = f"{description}\n\n{content}" + + await rag.ainsert(content) + logging.info(f"Successfully indexed file: {file.filename}") + + return InsertResponse( + status="success", + message=f"File '{file.filename}' successfully inserted", + document_count=1, + ) else: raise HTTPException( status_code=400, - detail="Unsupported file type. 
Only .txt and .md files are supported", + detail="No content could be extracted from the file", ) - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) except UnicodeDecodeError: raise HTTPException(status_code=400, detail="File encoding not supported") except Exception as e: + logging.error(f"Error processing file {file.filename}: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - @app.post( "/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)], ) async def insert_batch(files: List[UploadFile] = File(...)): + """Process multiple files in batch mode + + Args: + files: List of files to process + + Returns: + InsertResponse: Status of the batch insertion operation + + Raises: + HTTPException: For processing errors + """ try: inserted_count = 0 failed_files = [] for file in files: try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) + content = "" + ext = Path(file.filename).suffix.lower() + + match ext: + case ".txt" | ".md": + text_content = await file.read() + content = text_content.decode("utf-8") + + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from pypdf2 import PdfReader + from io import BytesIO + + pdf_content = await file.read() + pdf_file = BytesIO(pdf_content) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_content = await file.read() + docx_file = BytesIO(docx_content) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_content = await file.read() + pptx_file = BytesIO(pptx_content) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + + case _: + failed_files.append(f"{file.filename} (unsupported type)") + continue + + if content: + await rag.ainsert(content) inserted_count += 1 + logging.info(f"Successfully indexed file: {file.filename}") else: - failed_files.append(f"{file.filename} (unsupported type)") + failed_files.append(f"{file.filename} (no content extracted)") + + except UnicodeDecodeError: + failed_files.append(f"{file.filename} (encoding error)") except Exception as e: failed_files.append(f"{file.filename} ({str(e)})") + logging.error(f"Error processing file {file.filename}: {str(e)}") - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" + # Prepare status message + if inserted_count == len(files): + status = "success" + status_message = f"Successfully inserted all {inserted_count} documents" + elif inserted_count > 0: + status = "partial_success" + status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + else: + status = "failure" + status_message = "No documents were successfully inserted" + if failed_files: + status_message += f". 
Failed files: {', '.join(failed_files)}" return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", + status=status, message=status_message, - document_count=len(files), + document_count=inserted_count, ) + except Exception as e: + logging.error(f"Batch processing error: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) + @app.delete( "/documents", response_model=InsertResponse, diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 221d7f40..b8fc41b2 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -15,3 +15,4 @@ torch tqdm transformers uvicorn +pipmaster \ No newline at end of file From 29661c92da1a9828e320f6238deeb2861d61532f Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Tue, 14 Jan 2025 23:11:23 +0100 Subject: [PATCH 37/38] fixed linting --- lightrag/api/lightrag_server.py | 94 +++++++++++++++++---------------- lightrag/api/requirements.txt | 2 +- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d9f7bf06..0d154b38 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -175,7 +175,11 @@ def parse_args(): class DocumentManager: """Handles document operations and tracking""" - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".pptx")): + def __init__( + self, + input_dir: str, + supported_extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".pptx"), + ): self.input_dir = Path(input_dir) self.supported_extensions = supported_extensions self.indexed_files = set() @@ -357,26 +361,22 @@ def create_app(args): ), ) - - async def index_file(file_path: Union[str, Path]) -> None: - """ Index all files inside the folder with support for multiple file formats - + """Index all files inside the folder with support for multiple file formats + Args: file_path: Path to the file to be indexed (str or Path object) - + Raises: ValueError: If file format is not supported FileNotFoundError: If file doesn't exist """ if not pm.is_installed("aiofiles"): pm.install("aiofiles") - import aiofiles - - + # Convert to Path object if string file_path = Path(file_path) - + # Check if file exists if not file_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") @@ -384,23 +384,24 @@ def create_app(args): content = "" # Get file extension in lowercase ext = file_path.suffix.lower() - + match ext: case ".txt" | ".md": # Text files handling async with aiofiles.open(file_path, "r", encoding="utf-8") as f: content = await f.read() - + case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") from pypdf2 import PdfReader + # PDF handling reader = PdfReader(str(file_path)) content = "" for page in reader.pages: content += page.extract_text() + "\n" - + case ".docx": if not pm.is_installed("docx"): pm.install("docx") @@ -409,11 +410,12 @@ def create_app(args): # Word document handling doc = Document(file_path) content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) - + case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") from pptx import Presentation + # PowerPoint handling prs = Presentation(file_path) content = "" @@ -421,7 +423,7 @@ def create_app(args): for shape in slide.shapes: if hasattr(shape, "text"): content += shape.text + "\n" - + case _: raise ValueError(f"Unsupported file format: {ext}") @@ -433,9 +435,6 @@ def create_app(args): else: logging.warning(f"No content extracted from file: 
{file_path}") - - - @app.on_event("startup") async def startup_event(): """Index all files in input directory during startup""" @@ -559,6 +558,7 @@ def create_app(args): ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.post( "/documents/file", response_model=InsertResponse, @@ -566,14 +566,14 @@ def create_app(args): ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): """Insert a file directly into the RAG system - + Args: file: Uploaded file description: Optional description of the file - + Returns: InsertResponse: Status of the insertion operation - + Raises: HTTPException: For unsupported file types or processing errors """ @@ -581,19 +581,19 @@ def create_app(args): content = "" # Get file extension in lowercase ext = Path(file.filename).suffix.lower() - + match ext: case ".txt" | ".md": # Text files handling text_content = await file.read() content = text_content.decode("utf-8") - + case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") from pypdf2 import PdfReader from io import BytesIO - + # Read PDF from memory pdf_content = await file.read() pdf_file = BytesIO(pdf_content) @@ -601,25 +601,27 @@ def create_app(args): content = "" for page in reader.pages: content += page.extract_text() + "\n" - + case ".docx": if not pm.is_installed("docx"): pm.install("docx") from docx import Document from io import BytesIO - + # Read DOCX from memory docx_content = await file.read() docx_file = BytesIO(docx_content) doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) - + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) + case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") from pptx import Presentation from io import BytesIO - + # Read PPTX from memory pptx_content = await file.read() pptx_file = BytesIO(pptx_content) @@ -629,7 +631,7 @@ def create_app(args): for shape in slide.shapes: if hasattr(shape, "text"): content += shape.text + "\n" - + case _: raise HTTPException( status_code=400, @@ -641,10 +643,10 @@ def create_app(args): # Add description if provided if description: content = f"{description}\n\n{content}" - + await rag.ainsert(content) logging.info(f"Successfully indexed file: {file.filename}") - + return InsertResponse( status="success", message=f"File '{file.filename}' successfully inserted", @@ -661,6 +663,7 @@ def create_app(args): except Exception as e: logging.error(f"Error processing file {file.filename}: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) + @app.post( "/documents/batch", response_model=InsertResponse, @@ -668,13 +671,13 @@ def create_app(args): ) async def insert_batch(files: List[UploadFile] = File(...)): """Process multiple files in batch mode - + Args: files: List of files to process - + Returns: InsertResponse: Status of the batch insertion operation - + Raises: HTTPException: For processing errors """ @@ -686,41 +689,43 @@ def create_app(args): try: content = "" ext = Path(file.filename).suffix.lower() - + match ext: case ".txt" | ".md": text_content = await file.read() content = text_content.decode("utf-8") - + case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") from pypdf2 import PdfReader from io import BytesIO - + pdf_content = await file.read() pdf_file = BytesIO(pdf_content) reader = PdfReader(pdf_file) for page in reader.pages: content += page.extract_text() + "\n" - + case ".docx": if not pm.is_installed("docx"): pm.install("docx") from docx import Document 
from io import BytesIO - + docx_content = await file.read() docx_file = BytesIO(docx_content) doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) - + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) + case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") from pptx import Presentation from io import BytesIO - + pptx_content = await file.read() pptx_file = BytesIO(pptx_content) prs = Presentation(pptx_file) @@ -728,7 +733,7 @@ def create_app(args): for shape in slide.shapes: if hasattr(shape, "text"): content += shape.text + "\n" - + case _: failed_files.append(f"{file.filename} (unsupported type)") continue @@ -771,7 +776,6 @@ def create_app(args): logging.error(f"Batch processing error: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - @app.delete( "/documents", response_model=InsertResponse, diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index b8fc41b2..9154809c 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -7,6 +7,7 @@ nest_asyncio numpy ollama openai +pipmaster python-dotenv python-multipart tenacity @@ -15,4 +16,3 @@ torch tqdm transformers uvicorn -pipmaster \ No newline at end of file From 8f0196f6b9f273333fbbf21464be4c593b199c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=9C=A8Data=20Intelligence=20Lab=40HKU=E2=9C=A8?= <118165258+HKUDS@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:08:07 +0800 Subject: [PATCH 38/38] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 90c3ec04..71248056 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News -- [x] [2025.01.13]🎯📢Our team has launched [MiniRAG](https://github.com/HKUDS/MiniRAG) for small models. +- [x] [2025.01.13]🎯📢Our team has released [MiniRAG](https://github.com/HKUDS/MiniRAG) making RAG simpler with small models. - [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
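
Editor's note: patches 36–37 extend the API server to index `.pdf`, `.docx`, and `.pptx` uploads in addition to plain text, installing the parser libraries on demand via `pipmaster` and extracting text before handing it to `rag.ainsert`. A condensed sketch of that extraction step, assuming the parsers are already installed; the install names differ from the import names (`PyPDF2`, `python-docx`, `python-pptx`), and the exact module casing here is an assumption, not a quotation of the patch:

```python
from pathlib import Path

def extract_text(file_path: str) -> str:
    """Return the plain-text content of a supported document."""
    path = Path(file_path)
    ext = path.suffix.lower()
    if ext in (".txt", ".md"):
        return path.read_text(encoding="utf-8")
    if ext == ".pdf":
        from PyPDF2 import PdfReader  # install name: PyPDF2
        reader = PdfReader(str(path))
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if ext == ".docx":
        from docx import Document  # install name: python-docx
        doc = Document(str(path))
        return "\n".join(p.text for p in doc.paragraphs)
    if ext == ".pptx":
        from pptx import Presentation  # install name: python-pptx
        prs = Presentation(str(path))
        return "\n".join(
            shape.text
            for slide in prs.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        )
    raise ValueError(f"Unsupported file format: {ext}")

# Usage inside an async endpoint (rag assumed to be a configured LightRAG instance):
#     content = extract_text("report.pdf")
#     await rag.ainsert(content)
```
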