zrguo
2025-03-17 23:36:00 +08:00
parent bf18a5406e
commit 6115f60072
2 changed files with 73 additions and 45 deletions

View File

@@ -563,7 +563,9 @@ class LightRAG:
""" """
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
loop.run_until_complete( loop.run_until_complete(
self.ainsert(input, split_by_character, split_by_character_only, ids, file_paths) self.ainsert(
input, split_by_character, split_by_character_only, ids, file_paths
)
) )
async def ainsert( async def ainsert(
@@ -659,7 +661,10 @@ class LightRAG:
         await self._insert_done()

     async def apipeline_enqueue_documents(
-        self, input: str | list[str], ids: list[str] | None = None, file_paths: str | list[str] | None = None
+        self,
+        input: str | list[str],
+        ids: list[str] | None = None,
+        file_paths: str | list[str] | None = None,
     ) -> None:
         """
         Pipeline for Processing Documents
@@ -669,7 +674,7 @@ class LightRAG:
         3. Generate document initial status
         4. Filter out already processed documents
         5. Enqueue document in status
-
+
         Args:
             input: Single document string or list of document strings
             ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
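For context, a minimal usage sketch of the enqueue call whose signature is reformatted above, assuming `rag` is an already-configured LightRAG instance (constructor arguments omitted); it illustrates the rule, enforced below, that `file_paths`, when given, must match `input` one-to-one:

import asyncio

async def enqueue_reports(rag):  # `rag`: a configured LightRAG instance (assumed)
    docs = ["First report text ...", "Second report text ..."]
    await rag.apipeline_enqueue_documents(
        input=docs,
        ids=["doc-report-1", "doc-report-2"],             # optional; MD5 hash IDs are generated otherwise
        file_paths=["reports/q1.txt", "reports/q2.txt"],  # one path per document
    )
    # A mismatched length, e.g. file_paths=["reports/q1.txt"], raises ValueError.

# asyncio.run(enqueue_reports(rag))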
@@ -681,13 +686,15 @@ class LightRAG:
             ids = [ids]
         if isinstance(file_paths, str):
             file_paths = [file_paths]
-
+
         # If file_paths is provided, ensure it matches the number of documents
         if file_paths is not None:
             if isinstance(file_paths, str):
                 file_paths = [file_paths]
             if len(file_paths) != len(input):
-                raise ValueError("Number of file paths must match the number of documents")
+                raise ValueError(
+                    "Number of file paths must match the number of documents"
+                )
         else:
             # If no file paths provided, use placeholder
             file_paths = ["unknown_source"] * len(input)
@@ -703,22 +710,30 @@ class LightRAG:
raise ValueError("IDs must be unique") raise ValueError("IDs must be unique")
# Generate contents dict of IDs provided by user and documents # Generate contents dict of IDs provided by user and documents
contents = {id_: {"content": doc, "file_path": path} contents = {
for id_, doc, path in zip(ids, input, file_paths)} id_: {"content": doc, "file_path": path}
for id_, doc, path in zip(ids, input, file_paths)
}
else: else:
# Clean input text and remove duplicates # Clean input text and remove duplicates
cleaned_input = [(clean_text(doc), path) for doc, path in zip(input, file_paths)] cleaned_input = [
(clean_text(doc), path) for doc, path in zip(input, file_paths)
]
unique_content_with_paths = {} unique_content_with_paths = {}
# Keep track of unique content and their paths # Keep track of unique content and their paths
for content, path in cleaned_input: for content, path in cleaned_input:
if content not in unique_content_with_paths: if content not in unique_content_with_paths:
unique_content_with_paths[content] = path unique_content_with_paths[content] = path
# Generate contents dict of MD5 hash IDs and documents with paths # Generate contents dict of MD5 hash IDs and documents with paths
contents = {compute_mdhash_id(content, prefix="doc-"): contents = {
{"content": content, "file_path": path} compute_mdhash_id(content, prefix="doc-"): {
for content, path in unique_content_with_paths.items()} "content": content,
"file_path": path,
}
for content, path in unique_content_with_paths.items()
}
# 2. Remove duplicate contents # 2. Remove duplicate contents
unique_contents = {} unique_contents = {}
@@ -727,10 +742,12 @@ class LightRAG:
             file_path = content_data["file_path"]
             if content not in unique_contents:
                 unique_contents[content] = (id_, file_path)
-
+
         # Reconstruct contents with unique content
-        contents = {id_: {"content": content, "file_path": file_path}
-                    for content, (id_, file_path) in unique_contents.items()}
+        contents = {
+            id_: {"content": content, "file_path": file_path}
+            for content, (id_, file_path) in unique_contents.items()
+        }

         # 3. Generate document initial status
         new_docs: dict[str, Any] = {
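As a rough illustration of the de-duplication performed in the two hunks above (not part of the commit): documents whose cleaned text is identical collapse to a single `doc-` entry, and the first file path seen for that content wins. The library's `clean_text` and `compute_mdhash_id` helpers are approximated here with `str.strip` and `hashlib.md5`:

from hashlib import md5

def doc_id(cleaned: str) -> str:
    # Rough stand-in for compute_mdhash_id(cleaned, prefix="doc-")
    return "doc-" + md5(cleaned.encode()).hexdigest()

docs = ["  Same text\n", "Same text", "Other text"]
file_paths = ["a.txt", "b.txt", "c.txt"]

unique = {}
for doc, path in zip(docs, file_paths):
    content = doc.strip()             # stand-in for clean_text(doc)
    unique.setdefault(content, path)  # first file path wins, as in the code above

contents = {doc_id(c): {"content": c, "file_path": p} for c, p in unique.items()}
print(len(contents))  # 2 -- both copies of "Same text" share one document ID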
@@ -741,7 +758,9 @@ class LightRAG:
"content_length": len(content_data["content"]), "content_length": len(content_data["content"]),
"created_at": datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(),
"file_path": content_data["file_path"], # Store file path in document status "file_path": content_data[
"file_path"
], # Store file path in document status
} }
for id_, content_data in contents.items() for id_, content_data in contents.items()
} }
@@ -880,7 +899,7 @@ class LightRAG:
                 try:
                     # Get file path from status document
                     file_path = getattr(status_doc, "file_path", "unknown_source")
-
+
                     # Generate chunks from document
                     chunks: dict[str, Any] = {
                         compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -897,7 +916,7 @@ class LightRAG:
                             self.tiktoken_model_name,
                         )
                     }
-
+
                     # Process document (text chunks and full docs) in parallel
                     # Create tasks with references for potential cancellation
                     doc_status_task = asyncio.create_task(
@@ -1109,7 +1128,10 @@ class LightRAG:
         loop.run_until_complete(self.ainsert_custom_kg(custom_kg, full_doc_id))

     async def ainsert_custom_kg(
-        self, custom_kg: dict[str, Any], full_doc_id: str = None, file_path: str = "custom_kg"
+        self,
+        custom_kg: dict[str, Any],
+        full_doc_id: str = None,
+        file_path: str = "custom_kg",
     ) -> None:
         update_storage = False
         try:
@@ -3125,4 +3147,3 @@ class LightRAG:
             ]
         ]
     )
-

View File

@@ -224,7 +224,9 @@ async def _merge_nodes_then_upsert(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
         )
         already_file_paths.extend(
-            split_string_by_multi_markers(already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP])
+            split_string_by_multi_markers(
+                already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP]
+            )
         )
         already_description.append(already_node["description"])

@@ -290,7 +292,7 @@ async def _merge_edges_then_upsert(
already_edge["source_id"], [GRAPH_FIELD_SEP] already_edge["source_id"], [GRAPH_FIELD_SEP]
) )
) )
# Get file_path with empty string default if missing or None # Get file_path with empty string default if missing or None
if already_edge.get("file_path") is not None: if already_edge.get("file_path") is not None:
already_file_paths.extend( already_file_paths.extend(
@@ -336,7 +338,14 @@ async def _merge_edges_then_upsert(
         )
     )
     file_path = GRAPH_FIELD_SEP.join(
-        set([dp["metadata"]["file_path"] for dp in edges_data if dp.get("metadata", {}).get("file_path")] + already_file_paths)
+        set(
+            [
+                dp["metadata"]["file_path"]
+                for dp in edges_data
+                if dp.get("metadata", {}).get("file_path")
+            ]
+            + already_file_paths
+        )
     )

     for need_insert_id in [src_id, tgt_id]:
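A runnable sketch of the path-merging expression in this hunk, with `GRAPH_FIELD_SEP` stubbed as a local constant (the real constant is defined elsewhere in the library; its value here is only illustrative):

GRAPH_FIELD_SEP = "<SEP>"  # stand-in for the library's separator constant

edges_data = [
    {"metadata": {"file_path": "reports/q1.txt"}},
    {"metadata": {}},                    # records without a path are filtered out
    {"metadata": {"file_path": "notes/x.md"}},
]
already_file_paths = ["reports/q1.txt"]  # paths already stored on the edge

file_path = GRAPH_FIELD_SEP.join(
    set(
        [
            dp["metadata"]["file_path"]
            for dp in edges_data
            if dp.get("metadata", {}).get("file_path")
        ]
        + already_file_paths
    )
)
print(file_path)  # e.g. "notes/x.md<SEP>reports/q1.txt" (set order not guaranteed)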
@@ -482,7 +491,9 @@ async def extract_entities(
         else:
             return await use_llm_func(input_text)

-    async def _process_extraction_result(result: str, chunk_key: str, file_path: str = "unknown_source"):
+    async def _process_extraction_result(
+        result: str, chunk_key: str, file_path: str = "unknown_source"
+    ):
         """Process a single extraction result (either initial or gleaning)
         Args:
             result (str): The extraction result to process
@@ -623,7 +634,7 @@ async def extract_entities(
             for k, v in maybe_edges.items()
         ]
     )
-
+
     if not (all_entities_data or all_relationships_data):
         log_message = "Didn't extract any entities and relationships."
         logger.info(log_message)
@@ -669,7 +680,9 @@ async def extract_entities(
"file_path": dp.get("metadata", {}).get("file_path", "unknown_source"), "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
"metadata": { "metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time()), "created_at": dp.get("metadata", {}).get("created_at", time.time()),
"file_path": dp.get("metadata", {}).get("file_path", "unknown_source"), "file_path": dp.get("metadata", {}).get(
"file_path", "unknown_source"
),
}, },
} }
for dp in all_entities_data for dp in all_entities_data
@@ -687,7 +700,9 @@ async def extract_entities(
"file_path": dp.get("metadata", {}).get("file_path", "unknown_source"), "file_path": dp.get("metadata", {}).get("file_path", "unknown_source"),
"metadata": { "metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time()), "created_at": dp.get("metadata", {}).get("created_at", time.time()),
"file_path": dp.get("metadata", {}).get("file_path", "unknown_source"), "file_path": dp.get("metadata", {}).get(
"file_path", "unknown_source"
),
}, },
} }
for dp in all_relationships_data for dp in all_relationships_data
@@ -1272,13 +1287,13 @@ async def _get_node_data(
         created_at = n.get("created_at", "UNKNOWN")
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
-
+
         # Get file path from metadata or directly from node data
         file_path = n.get("file_path", "unknown_source")
         if not file_path or file_path == "unknown_source":
             # Try to get from metadata
             file_path = n.get("metadata", {}).get("file_path", "unknown_source")
-
+
         entites_section_list.append(
             [
                 i,
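The lookup pattern in this hunk (top-level "file_path", then the metadata dict, then the "unknown_source" placeholder) recurs in the hunks below; a small hypothetical helper capturing it, not introduced by this commit:

from typing import Any

def resolve_file_path(record: dict[str, Any]) -> str:
    # Prefer a top-level "file_path", fall back to metadata["file_path"],
    # and finally to the "unknown_source" placeholder.
    file_path = record.get("file_path", "unknown_source")
    if not file_path or file_path == "unknown_source":
        file_path = record.get("metadata", {}).get("file_path", "unknown_source")
    return file_path

print(resolve_file_path({"file_path": "docs/a.txt"}))                # docs/a.txt
print(resolve_file_path({"metadata": {"file_path": "docs/b.txt"}}))  # docs/b.txt
print(resolve_file_path({}))                                         # unknown_source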
@@ -1310,13 +1325,13 @@ async def _get_node_data(
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
-
+
         # Get file path from metadata or directly from edge data
         file_path = e.get("file_path", "unknown_source")
         if not file_path or file_path == "unknown_source":
             # Try to get from metadata
             file_path = e.get("metadata", {}).get("file_path", "unknown_source")
-
+
         relations_section_list.append(
             [
                 i,
@@ -1551,13 +1566,13 @@ async def _get_edge_data(
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
-
+
         # Get file path from metadata or directly from edge data
         file_path = e.get("file_path", "unknown_source")
         if not file_path or file_path == "unknown_source":
             # Try to get from metadata
             file_path = e.get("metadata", {}).get("file_path", "unknown_source")
-
+
         relations_section_list.append(
             [
                 i,
@@ -1574,28 +1589,20 @@ async def _get_edge_data(
     relations_context = list_of_list_to_csv(relations_section_list)

     entites_section_list = [
-        [
-            "id",
-            "entity",
-            "type",
-            "description",
-            "rank",
-            "created_at",
-            "file_path"
-        ]
+        ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
     ]
     for i, n in enumerate(use_entities):
         created_at = n.get("created_at", "Unknown")
         # Convert timestamp to readable format
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
-
+
         # Get file path from metadata or directly from node data
         file_path = n.get("file_path", "unknown_source")
         if not file_path or file_path == "unknown_source":
             # Try to get from metadata
             file_path = n.get("metadata", {}).get("file_path", "unknown_source")
-
+
         entites_section_list.append(
             [
                 i,