diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index cbe49da2..47d64ac0 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -45,6 +45,7 @@ from .storage import (
 )
 
 from .prompt import GRAPH_FIELD_SEP
+
 
 # future KG integrations
 # from .kg.ArangoDB_impl import (
@@ -167,7 +168,7 @@ class LightRAG:
 
     # LLM
     llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -267,7 +268,7 @@ class LightRAG:
                 self.llm_model_func,
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                   and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -313,15 +314,16 @@ class LightRAG:
             "JsonDocStatusStorage": JsonDocStatusStorage,
         }
 
-    def insert(self, string_or_strings):
+    def insert(self, string_or_strings, split_by_character=None):
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.ainsert(string_or_strings))
+        return loop.run_until_complete(self.ainsert(string_or_strings, split_by_character))
 
-    async def ainsert(self, string_or_strings):
+    async def ainsert(self, string_or_strings, split_by_character):
         """Insert documents with checkpoint support
 
         Args:
             string_or_strings: Single document string or list of document strings
+            split_by_character: if split_by_character is not None, split the string by character
         """
         if isinstance(string_or_strings, str):
             string_or_strings = [string_or_strings]
@@ -355,10 +357,10 @@ class LightRAG:
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
-            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            batch_docs = dict(list(new_docs.items())[i: i + batch_size])
 
             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
+                batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 try:
                     # Update status to processing
@@ -379,6 +381,7 @@ class LightRAG:
                         }
                         for dp in chunking_by_token_size(
                             doc["content"],
+                            split_by_character=split_by_character,
                             overlap_token_size=self.chunk_overlap_token_size,
                             max_token_size=self.chunk_token_size,
                             tiktoken_model=self.tiktoken_model_name,
@@ -545,7 +548,7 @@ class LightRAG:
                 # Check if nodes exist in the knowledge graph
                 for need_insert_id in [src_id, tgt_id]:
                     if not (
-                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
+                            await self.chunk_entity_relation_graph.has_node(need_insert_id)
                     ):
                         await self.chunk_entity_relation_graph.upsert_node(
                             need_insert_id,
@@ -594,9 +597,9 @@ class LightRAG:
                     "src_id": dp["src_id"],
                     "tgt_id": dp["tgt_id"],
                     "content": dp["keywords"]
-                    + dp["src_id"]
-                    + dp["tgt_id"]
-                    + dp["description"],
+                               + dp["src_id"]
+                               + dp["tgt_id"]
+                               + dp["description"],
                 }
                 for dp in all_relationships_data
             }
@@ -621,7 +624,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                   and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -637,7 +640,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                   and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -656,7 +659,7 @@ class LightRAG:
                 asdict(self),
                 hashing_kv=self.llm_response_cache
                 if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
+                   and hasattr(self.llm_response_cache, "global_config")
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
@@ -897,7 +900,7 @@ class LightRAG:
                     dp
                     for dp in self.entities_vdb.client_storage["data"]
                     if chunk_id
-                    in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                       in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                 ]
                 if entities_with_chunk:
                     logger.error(
@@ -909,7 +912,7 @@ class LightRAG:
                     dp
                    for dp in self.relationships_vdb.client_storage["data"]
                     if chunk_id
-                    in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                       in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
                 ]
                 if relations_with_chunk:
                     logger.error(
@@ -926,7 +929,7 @@ class LightRAG:
         return asyncio.run(self.adelete_by_doc_id(doc_id))
 
     async def get_entity_info(
-        self, entity_name: str, include_vector_data: bool = False
+            self, entity_name: str, include_vector_data: bool = False
     ):
         """Get detailed information of an entity
 
@@ -977,7 +980,7 @@ class LightRAG:
         tracemalloc.stop()
 
     async def get_relation_info(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
+            self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Get detailed information of a relationship
 
@@ -1019,7 +1022,7 @@ class LightRAG:
         return result
 
     def get_relation_info_sync(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
+            self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
         """Synchronous version of getting relationship information
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index b2c4d215..e8f0df65 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -34,30 +34,52 @@ import time
 
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+        content: str, split_by_character=None, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
     tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
     results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        for chunk in raw_chunks:
+            _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+            if len(_tokens) > max_token_size:
+                for start in range(0, len(_tokens), max_token_size - overlap_token_size):
+                    chunk_content = decode_tokens_by_tiktoken(
+                        _tokens[start: start + max_token_size], model_name=tiktoken_model
+                    )
+                    new_chunks.append((min(max_token_size, len(_tokens) - start), chunk_content))
+            else:
+                new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
+            chunk_content = decode_tokens_by_tiktoken(
+                tokens[start: start + max_token_size], model_name=tiktoken_model
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
     return results
 
 
 async def _handle_entity_relation_summary(
-    entity_or_relation_name: str,
-    description: str,
-    global_config: dict,
+        entity_or_relation_name: str,
+        description: str,
+        global_config: dict,
 ) -> str:
     use_llm_func: callable = global_config["llm_model_func"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
@@ -86,8 +108,8 @@ async def _handle_entity_relation_summary(
 
 
 async def _handle_single_entity_extraction(
-    record_attributes: list[str],
-    chunk_key: str,
+        record_attributes: list[str],
+        chunk_key: str,
 ):
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
@@ -107,8 +129,8 @@ async def _handle_single_entity_extraction(
 
 
 async def _handle_single_relationship_extraction(
-    record_attributes: list[str],
-    chunk_key: str,
+        record_attributes: list[str],
+        chunk_key: str,
 ):
     if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
         return None
@@ -134,10 +156,10 @@ async def _handle_single_relationship_extraction(
 
 
 async def _merge_nodes_then_upsert(
-    entity_name: str,
-    nodes_data: list[dict],
-    knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict,
+        entity_name: str,
+        nodes_data: list[dict],
+        knowledge_graph_inst: BaseGraphStorage,
+        global_config: dict,
 ):
     already_entity_types = []
     already_source_ids = []
@@ -181,11 +203,11 @@ async def _merge_nodes_then_upsert(
 
 
 async def _merge_edges_then_upsert(
-    src_id: str,
-    tgt_id: str,
-    edges_data: list[dict],
-    knowledge_graph_inst: BaseGraphStorage,
-    global_config: dict,
+        src_id: str,
+        tgt_id: str,
+        edges_data: list[dict],
+        knowledge_graph_inst: BaseGraphStorage,
+        global_config: dict,
 ):
     already_weights = []
     already_source_ids = []
@@ -248,12 +270,12 @@ async def _merge_edges_then_upsert(
 
 
 async def extract_entities(
-    chunks: dict[str, TextChunkSchema],
-    knowledge_graph_inst: BaseGraphStorage,
-    entity_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
+        chunks: dict[str, TextChunkSchema],
+        knowledge_graph_inst: BaseGraphStorage,
+        entity_vdb: BaseVectorStorage,
+        relationships_vdb: BaseVectorStorage,
+        global_config: dict,
+        llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -305,13 +327,13 @@ async def extract_entities(
     already_relations = 0
 
     async def _user_llm_func_with_cache(
-        input_text: str, history_messages: list[dict[str, str]] = None
+            input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
             need_to_restore = False
             if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
+                    global_config["embedding_cache_config"]
+                    and global_config["embedding_cache_config"]["enabled"]
             ):
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
@@ -413,7 +435,7 @@ async def extract_entities(
         already_relations += len(maybe_edges)
         now_ticks = PROMPTS["process_tickers"][
             already_processed % len(PROMPTS["process_tickers"])
-        ]
+            ]
         print(
             f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
             end="",
@@ -423,10 +445,10 @@ async def extract_entities(
 
     results = []
     for result in tqdm_async(
-        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
-        total=len(ordered_chunks),
-        desc="Extracting entities from chunks",
-        unit="chunk",
+            asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
+            total=len(ordered_chunks),
+            desc="Extracting entities from chunks",
+            unit="chunk",
     ):
         results.append(await result)
 
@@ -440,32 +462,32 @@ async def extract_entities(
     logger.info("Inserting entities into storage...")
     all_entities_data = []
     for result in tqdm_async(
-        asyncio.as_completed(
-            [
-                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-                for k, v in maybe_nodes.items()
-            ]
-        ),
-        total=len(maybe_nodes),
-        desc="Inserting entities",
-        unit="entity",
+            asyncio.as_completed(
+                [
+                    _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+                    for k, v in maybe_nodes.items()
+                ]
+            ),
+            total=len(maybe_nodes),
+            desc="Inserting entities",
+            unit="entity",
     ):
         all_entities_data.append(await result)
 
     logger.info("Inserting relationships into storage...")
     all_relationships_data = []
     for result in tqdm_async(
-        asyncio.as_completed(
-            [
-                _merge_edges_then_upsert(
-                    k[0], k[1], v, knowledge_graph_inst, global_config
-                )
-                for k, v in maybe_edges.items()
-            ]
-        ),
-        total=len(maybe_edges),
-        desc="Inserting relationships",
-        unit="relationship",
+            asyncio.as_completed(
+                [
+                    _merge_edges_then_upsert(
+                        k[0], k[1], v, knowledge_graph_inst, global_config
+                    )
+                    for k, v in maybe_edges.items()
+                ]
+            ),
+            total=len(maybe_edges),
+            desc="Inserting relationships",
+            unit="relationship",
     ):
         all_relationships_data.append(await result)
 
@@ -496,9 +518,9 @@ async def extract_entities(
             "src_id": dp["src_id"],
             "tgt_id": dp["tgt_id"],
             "content": dp["keywords"]
-            + dp["src_id"]
-            + dp["tgt_id"]
-            + dp["description"],
+                       + dp["src_id"]
+                       + dp["tgt_id"]
+                       + dp["description"],
             "metadata": {
                 "created_at": dp.get("metadata", {}).get("created_at", time.time())
             },
@@ -511,14 +533,14 @@
 
 
 async def kg_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+        query,
+        knowledge_graph_inst: BaseGraphStorage,
+        entities_vdb: BaseVectorStorage,
+        relationships_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
+        global_config: dict,
+        hashing_kv: BaseKVStorage = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -638,12 +660,12 @@ async def kg_query(
 
 
 async def _build_query_context(
-    query: list,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
+        query: list,
+        knowledge_graph_inst: BaseGraphStorage,
+        entities_vdb: BaseVectorStorage,
+        relationships_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
 ):
     # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
     # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
@@ -696,9 +718,9 @@ async def _build_query_context(
             query_param,
         )
         if (
-            hl_entities_context == ""
-            and hl_relations_context == ""
-            and hl_text_units_context == ""
+                hl_entities_context == ""
+                and hl_relations_context == ""
+                and hl_text_units_context == ""
         ):
             logger.warn("No high level context found. Switching to local mode.")
             query_param.mode = "local"
@@ -737,11 +759,11 @@ async def _build_query_context(
 
 
 async def _get_node_data(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
+        query,
+        knowledge_graph_inst: BaseGraphStorage,
+        entities_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
 ):
     # get similar entities
     results = await entities_vdb.query(query, top_k=query_param.top_k)
@@ -828,10 +850,10 @@ async def _get_node_data(
 
 
 async def _find_most_related_text_unit_from_entities(
-    node_datas: list[dict],
-    query_param: QueryParam,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    knowledge_graph_inst: BaseGraphStorage,
+        node_datas: list[dict],
+        query_param: QueryParam,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -871,8 +893,8 @@ async def _find_most_related_text_unit_from_entities(
             if this_edges:
                 for e in this_edges:
                     if (
-                        e[1] in all_one_hop_text_units_lookup
-                        and c_id in all_one_hop_text_units_lookup[e[1]]
+                            e[1] in all_one_hop_text_units_lookup
+                            and c_id in all_one_hop_text_units_lookup[e[1]]
                     ):
                         all_text_units_lookup[c_id]["relation_counts"] += 1
@@ -902,9 +924,9 @@ async def _find_most_related_text_unit_from_entities(
 
 
 async def _find_most_related_edges_from_entities(
-    node_datas: list[dict],
-    query_param: QueryParam,
-    knowledge_graph_inst: BaseGraphStorage,
+        node_datas: list[dict],
+        query_param: QueryParam,
+        knowledge_graph_inst: BaseGraphStorage,
 ):
     all_related_edges = await asyncio.gather(
         *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
@@ -942,11 +964,11 @@ async def _find_most_related_edges_from_entities(
 
 
 async def _get_edge_data(
-    keywords,
-    knowledge_graph_inst: BaseGraphStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
+        keywords,
+        knowledge_graph_inst: BaseGraphStorage,
+        relationships_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
 ):
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@@ -1044,9 +1066,9 @@ async def _get_edge_data(
 
 
 async def _find_most_related_entities_from_relationships(
-    edge_datas: list[dict],
-    query_param: QueryParam,
-    knowledge_graph_inst: BaseGraphStorage,
+        edge_datas: list[dict],
+        query_param: QueryParam,
+        knowledge_graph_inst: BaseGraphStorage,
 ):
     entity_names = []
     seen = set()
@@ -1081,10 +1103,10 @@ async def _find_most_related_entities_from_relationships(
 
 
 async def _find_related_text_unit_from_relationships(
-    edge_datas: list[dict],
-    query_param: QueryParam,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    knowledge_graph_inst: BaseGraphStorage,
+        edge_datas: list[dict],
+        query_param: QueryParam,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        knowledge_graph_inst: BaseGraphStorage,
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1150,12 +1172,12 @@ def combine_contexts(entities, relationships, sources):
 
 
 async def naive_query(
-    query,
-    chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+        query,
+        chunks_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
+        global_config: dict,
+        hashing_kv: BaseKVStorage = None,
 ):
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1213,7 +1235,7 @@ async def naive_query(
     if len(response) > len(sys_prompt):
         response = (
-            response[len(sys_prompt) :]
+            response[len(sys_prompt):]
             .replace(sys_prompt, "")
             .replace("user", "")
             .replace("model", "")
@@ -1241,15 +1263,15 @@ async def naive_query(
 
 
 async def mix_kg_vector_query(
-    query,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage[TextChunkSchema],
-    query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+        query,
+        knowledge_graph_inst: BaseGraphStorage,
+        entities_vdb: BaseVectorStorage,
+        relationships_vdb: BaseVectorStorage,
+        chunks_vdb: BaseVectorStorage,
+        text_chunks_db: BaseKVStorage[TextChunkSchema],
+        query_param: QueryParam,
+        global_config: dict,
+        hashing_kv: BaseKVStorage = None,
 ) -> str:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1274,7 +1296,7 @@ async def mix_kg_vector_query(
     # Reuse keyword extraction logic from kg_query
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(
-        PROMPTS["keywords_extraction_examples"]
+            PROMPTS["keywords_extraction_examples"]
     ):
         examples = "\n".join(
             PROMPTS["keywords_extraction_examples"][: int(example_number)]
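For reference, a minimal usage sketch of the new `split_by_character` parameter on `LightRAG.insert` as changed in this patch. It assumes `lightrag` is importable and that the instance is constructed elsewhere with whatever LLM/embedding/storage settings the deployment needs; the `working_dir` argument below is illustrative only and is not part of this diff.

    from lightrag import LightRAG

    # Illustrative setup; real deployments pass their own LLM, embedding,
    # and storage configuration.
    rag = LightRAG(working_dir="./rag_storage")

    # Unchanged behaviour: split_by_character defaults to None, so the
    # document is chunked purely by token windows (chunk_token_size with
    # chunk_overlap_token_size overlap).
    rag.insert("one long document ...")

    # New behaviour: split on the given character first (here a newline);
    # only pieces whose token count exceeds max_token_size are further
    # token-chunked with overlap.
    rag.insert("section 1\nsection 2\nsection 3", split_by_character="\n")

When `split_by_character` is set, pieces shorter than the token limit are kept whole, so chunk boundaries follow the document's own structure rather than fixed token windows.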