Merge branch 'main' of github.com:lcjqyml/LightRAG

Milin committed on 2025-03-21 15:22:10 +08:00
7 changed files with 123 additions and 110 deletions

View File

@@ -73,6 +73,8 @@ LLM_BINDING_HOST=http://localhost:11434
 ### Embedding Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
+EMBEDDING_BATCH_NUM=32
+EMBEDDING_FUNC_MAX_ASYNC=16
 # EMBEDDING_BINDING_API_KEY=your_api_key
 ### ollama example
 EMBEDDING_BINDING=ollama
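
Note: these two new settings feed the env-driven defaults added to lightrag.py later in this commit. A minimal sketch of how they are picked up (the helper name is hypothetical; the int(os.getenv(...)) pattern mirrors the dataclass change below):

```python
import os

# Hypothetical helper; mirrors the lightrag.py defaults introduced in this commit.
def embedding_settings() -> dict[str, int]:
    return {
        "embedding_batch_num": int(os.getenv("EMBEDDING_BATCH_NUM", 32)),
        "embedding_func_max_async": int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)),
    }

print(embedding_settings())  # with the values above: {'embedding_batch_num': 32, 'embedding_func_max_async': 16}
```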

View File

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.2.6"
+__version__ = "1.2.7"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -747,8 +747,30 @@ class PGDocStatusStorage(DocStatusStorage):
         )

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """Get doc_chunks data by id"""
-        raise NotImplementedError
+        """Get doc_chunks data by multiple IDs."""
+        if not ids:
+            return []
+
+        sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
+        params = {"workspace": self.db.workspace, "ids": ids}
+
+        results = await self.db.query(sql, params, True)
+
+        if not results:
+            return []
+        return [
+            {
+                "content": row["content"],
+                "content_length": row["content_length"],
+                "content_summary": row["content_summary"],
+                "status": row["status"],
+                "chunks_count": row["chunks_count"],
+                "created_at": row["created_at"],
+                "updated_at": row["updated_at"],
+                "file_path": row["file_path"],
+            }
+            for row in results
+        ]

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
@@ -1570,7 +1592,7 @@ TABLES = {
             content_vector VECTOR,
             create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
             update_time TIMESTAMP,
-            chunk_id TEXT NULL,
+            chunk_ids VARCHAR(255)[] NULL,
             file_path TEXT NULL,
             CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
             )"""
@@ -1585,7 +1607,7 @@ TABLES = {
             content_vector VECTOR,
             create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
             update_time TIMESTAMP,
-            chunk_id TEXT NULL,
+            chunk_ids VARCHAR(255)[] NULL,
             file_path TEXT NULL,
             CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
             )"""
@@ -1673,7 +1695,7 @@ SQL_TEMPLATES = {
""", """,
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path) content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7::varchar[]) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET entity_name=EXCLUDED.entity_name, SET entity_name=EXCLUDED.entity_name,
content=EXCLUDED.content, content=EXCLUDED.content,
@@ -1684,7 +1706,7 @@ SQL_TEMPLATES = {
""", """,
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
target_id, content, content_vector, chunk_ids, file_path) target_id, content, content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8::varchar[]) VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8)
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET source_id=EXCLUDED.source_id, SET source_id=EXCLUDED.source_id,
target_id=EXCLUDED.target_id, target_id=EXCLUDED.target_id,
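
The two VALUES changes drop the stray ::varchar[] cast on the trailing file_path parameter: chunk_ids is now bound from a Python list into the VARCHAR(255)[] column, while file_path stays plain TEXT. A sketch of the binding semantics with asyncpg-style $N placeholders (illustrative values; direct asyncpg use is an assumption here, the production code goes through the storage layer's query helpers):

```python
import asyncpg  # assumed driver for this sketch

async def bind_array_params(conn: asyncpg.Connection) -> None:
    # A Python list binds to the varchar[] parameter; a plain str binds to TEXT.
    await conn.execute(
        "UPDATE LIGHTRAG_VDB_ENTITY SET chunk_ids = $1::varchar[], file_path = $2 "
        "WHERE workspace = $3 AND id = $4",
        ["chunk-abc", "chunk-def"],  # list -> VARCHAR(255)[]
        "docs/example.txt",          # str  -> TEXT (no array cast)
        "default",                   # workspace (illustrative)
        "ent-123",                   # id (illustrative)
    )
```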

View File

@@ -183,10 +183,10 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""

-    embedding_batch_num: int = field(default=32)
+    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
     """Batch size for embedding computations."""

-    embedding_func_max_async: int = field(default=16)
+    embedding_func_max_async: int = field(default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)))
     """Maximum number of concurrent embedding function calls."""

     embedding_cache_config: dict[str, Any] = field(
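
Because field(default=int(os.getenv(...))) is evaluated when the class body runs, the environment variables must be set before lightrag is imported. A small sketch of that behaviour (assumes the lightrag package and its dependencies are installed):

```python
import os

# Must happen before the lightrag import; the defaults are captured at import time.
os.environ["EMBEDDING_BATCH_NUM"] = "64"
os.environ["EMBEDDING_FUNC_MAX_ASYNC"] = "8"

from dataclasses import fields
from lightrag import LightRAG  # noqa: E402

defaults = {
    f.name: f.default
    for f in fields(LightRAG)
    if f.name in ("embedding_batch_num", "embedding_func_max_async")
}
print(defaults)  # {'embedding_batch_num': 64, 'embedding_func_max_async': 8}
```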
@@ -1947,6 +1947,8 @@ class LightRAG:
         # 2. Update entity information in the graph
         new_node_data = {**node_data, **updated_data}
+        new_node_data["entity_id"] = new_entity_name
+
         if "entity_name" in new_node_data:
             del new_node_data[
                 "entity_name"
@@ -1963,7 +1965,7 @@ class LightRAG:
         # Store relationships that need to be updated
         relations_to_update = []
+        relations_to_delete = []

         # Get all edges related to the original entity
         edges = await self.chunk_entity_relation_graph.get_node_edges(
             entity_name
@@ -1975,6 +1977,12 @@ class LightRAG:
                     source, target
                 )
                 if edge_data:
+                    relations_to_delete.append(
+                        compute_mdhash_id(source + target, prefix="rel-")
+                    )
+                    relations_to_delete.append(
+                        compute_mdhash_id(target + source, prefix="rel-")
+                    )
+
                     if source == entity_name:
                         await self.chunk_entity_relation_graph.upsert_edge(
                             new_entity_name, target, edge_data
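
Both orderings of the endpoints are queued for deletion because the vector-DB id is derived from the concatenated names, and the edge may have been stored in either direction. A sketch of the id scheme (the import path and the md5-based internals of compute_mdhash_id are assumptions; the calls themselves match the hunk above):

```python
from lightrag.utils import compute_mdhash_id  # assumed import path

source, target = "Alan Turing", "Enigma"  # illustrative entity names
relations_to_delete = [
    compute_mdhash_id(source + target, prefix="rel-"),  # edge stored as (source, target)
    compute_mdhash_id(target + source, prefix="rel-"),  # or as (target, source)
]
print(relations_to_delete)
```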
@@ -2000,6 +2008,12 @@ class LightRAG:
f"Deleted old entity '{entity_name}' and its vector embedding from database" f"Deleted old entity '{entity_name}' and its vector embedding from database"
) )
# Delete old relation records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# Update relationship vector representations # Update relationship vector representations
for src, tgt, edge_data in relations_to_update: for src, tgt, edge_data in relations_to_update:
description = edge_data.get("description", "") description = edge_data.get("description", "")
@@ -2498,39 +2512,21 @@ class LightRAG:
         # 4. Get all relationships of the source entities
         all_relations = []
         for entity_name in source_entities:
-            # Get all relationships where this entity is the source
-            outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges(
+            # Get all relationships of the source entities
+            edges = await self.chunk_entity_relation_graph.get_node_edges(
                 entity_name
             )
-            if outgoing_edges:
-                for src, tgt in outgoing_edges:
+            if edges:
+                for src, tgt in edges:
                     # Ensure src is the current entity
                     if src == entity_name:
                         edge_data = await self.chunk_entity_relation_graph.get_edge(
                             src, tgt
                         )
-                        all_relations.append(("outgoing", src, tgt, edge_data))
-
-            # Get all relationships where this entity is the target
-            incoming_edges = []
-            all_labels = await self.chunk_entity_relation_graph.get_all_labels()
-            for label in all_labels:
-                if label == entity_name:
-                    continue
-                node_edges = await self.chunk_entity_relation_graph.get_node_edges(
-                    label
-                )
-                for src, tgt in node_edges or []:
-                    if tgt == entity_name:
-                        incoming_edges.append((src, tgt))
-
-            for src, tgt in incoming_edges:
-                edge_data = await self.chunk_entity_relation_graph.get_edge(
-                    src, tgt
-                )
-                all_relations.append(("incoming", src, tgt, edge_data))
+                        all_relations.append((src, tgt, edge_data))

         # 5. Create or update the target entity
+        merged_entity_data["entity_id"] = target_entity
         if not target_exists:
             await self.chunk_entity_relation_graph.upsert_node(
                 target_entity, merged_entity_data
@@ -2544,8 +2540,11 @@ class LightRAG:
         # 6. Recreate all relationships, pointing to the target entity
         relation_updates = {}  # Track relationships that need to be merged
+        relations_to_delete = []

-        for rel_type, src, tgt, edge_data in all_relations:
+        for src, tgt, edge_data in all_relations:
+            relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-"))
+            relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-"))
             new_src = target_entity if src in source_entities else src
             new_tgt = target_entity if tgt in source_entities else tgt
@@ -2590,6 +2589,12 @@ class LightRAG:
f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}" f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
) )
# Delete relationships records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# 7. Update entity vector representation # 7. Update entity vector representation
description = merged_entity_data.get("description", "") description = merged_entity_data.get("description", "")
source_id = merged_entity_data.get("source_id", "") source_id = merged_entity_data.get("source_id", "")
@@ -2652,19 +2657,6 @@ class LightRAG:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             await self.entities_vdb.delete([entity_id])

-            # Also ensure any relationships specific to this entity are deleted from vector DB
-            # This is a safety check, as these should have been transformed to the target entity already
-            entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
-            relations_with_entity = await self.relationships_vdb.search_by_prefix(
-                entity_relation_prefix
-            )
-            if relations_with_entity:
-                relation_ids = [r["id"] for r in relations_with_entity]
-                await self.relationships_vdb.delete(relation_ids)
-                logger.info(
-                    f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
-                )
-
             logger.info(
                 f"Deleted source entity '{entity_name}' and its vector embedding from database"
             )

View File

@@ -138,16 +138,31 @@ async def hf_model_complete(
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    device = next(embed_model.parameters()).device
+    # Detect the appropriate device
+    if torch.cuda.is_available():
+        device = next(embed_model.parameters()).device  # Use CUDA if available
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")  # Use MPS for Apple Silicon
+    else:
+        device = torch.device("cpu")  # Fallback to CPU
+
+    # Move the model to the detected device
+    embed_model = embed_model.to(device)
+
+    # Tokenize the input texts and move them to the same device
     encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
     ).to(device)
+
+    # Perform inference
     with torch.no_grad():
         outputs = embed_model(
             input_ids=encoded_texts["input_ids"],
             attention_mask=encoded_texts["attention_mask"],
         )
         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+    # Convert embeddings to NumPy
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
     else:
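
A standalone sketch of the same device-selection logic, for reuse outside hf_embed (the helper name is hypothetical):

```python
import torch

def pick_device(model: torch.nn.Module) -> torch.device:
    """CUDA first, then Apple-Silicon MPS, then CPU, mirroring the change above."""
    if torch.cuda.is_available():
        return next(model.parameters()).device  # reuse the model's current device
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```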

View File

@@ -172,7 +172,7 @@ async def _handle_single_entity_extraction(
         entity_type=entity_type,
         description=entity_description,
         source_id=chunk_key,
-        metadata={"created_at": time.time(), "file_path": file_path},
+        file_path=file_path,
     )
@@ -201,7 +201,7 @@ async def _handle_single_relationship_extraction(
         description=edge_description,
         keywords=edge_keywords,
         source_id=edge_source_id,
-        metadata={"created_at": time.time(), "file_path": file_path},
+        file_path=file_path,
     )
@@ -224,9 +224,7 @@ async def _merge_nodes_then_upsert(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
         )
         already_file_paths.extend(
-            split_string_by_multi_markers(
-                already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP]
-            )
+            split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
         )
         already_description.append(already_node["description"])
@@ -244,7 +242,7 @@ async def _merge_nodes_then_upsert(
set([dp["source_id"] for dp in nodes_data] + already_source_ids) set([dp["source_id"] for dp in nodes_data] + already_source_ids)
) )
file_path = GRAPH_FIELD_SEP.join( file_path = GRAPH_FIELD_SEP.join(
set([dp["metadata"]["file_path"] for dp in nodes_data] + already_file_paths) set([dp["file_path"] for dp in nodes_data] + already_file_paths)
) )
logger.debug(f"file_path: {file_path}") logger.debug(f"file_path: {file_path}")
@@ -298,7 +296,7 @@ async def _merge_edges_then_upsert(
         if already_edge.get("file_path") is not None:
             already_file_paths.extend(
                 split_string_by_multi_markers(
-                    already_edge["metadata"]["file_path"], [GRAPH_FIELD_SEP]
+                    already_edge["file_path"], [GRAPH_FIELD_SEP]
                 )
             )
@@ -340,11 +338,7 @@ async def _merge_edges_then_upsert(
     )
     file_path = GRAPH_FIELD_SEP.join(
         set(
-            [
-                dp["metadata"]["file_path"]
-                for dp in edges_data
-                if dp.get("metadata", {}).get("file_path")
-            ]
+            [dp["file_path"] for dp in edges_data if dp.get("file_path")]
             + already_file_paths
         )
     )
@@ -679,10 +673,6 @@ async def extract_entities(
"content": f"{dp['entity_name']}\n{dp['description']}", "content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"], "source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"), "file_path": dp.get("file_path", "unknown_source"),
"metadata": {
"created_at": dp.get("created_at", time.time()),
"file_path": dp.get("file_path", "unknown_source"),
},
} }
for dp in all_entities_data for dp in all_entities_data
} }
@@ -697,10 +687,6 @@ async def extract_entities(
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"], "source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"), "file_path": dp.get("file_path", "unknown_source"),
"metadata": {
"created_at": dp.get("created_at", time.time()),
"file_path": dp.get("file_path", "unknown_source"),
},
} }
for dp in all_relationships_data for dp in all_relationships_data
} }
@@ -1285,11 +1271,8 @@ async def _get_node_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

-        # Get file path from metadata or directly from node data
+        # Get file path from node data
         file_path = n.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = n.get("metadata", {}).get("file_path", "unknown_source")

         entites_section_list.append(
             [
@@ -1323,11 +1306,8 @@ async def _get_node_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

-        # Get file path from metadata or directly from edge data
+        # Get file path from edge data
         file_path = e.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = e.get("metadata", {}).get("file_path", "unknown_source")

         relations_section_list.append(
             [
@@ -1564,11 +1544,8 @@ async def _get_edge_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

-        # Get file path from metadata or directly from edge data
+        # Get file path from edge data
         file_path = e.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = e.get("metadata", {}).get("file_path", "unknown_source")

         relations_section_list.append(
             [
@@ -1594,11 +1571,8 @@ async def _get_edge_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))

-        # Get file path from metadata or directly from node data
+        # Get file path from node data
         file_path = n.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = n.get("metadata", {}).get("file_path", "unknown_source")

         entites_section_list.append(
             [

View File

@@ -109,15 +109,17 @@ def setup_logger(
     logger_name: str,
     level: str = "INFO",
     add_filter: bool = False,
-    log_file_path: str = None,
+    log_file_path: str | None = None,
+    enable_file_logging: bool = True,
 ):
-    """Set up a logger with console and file handlers
+    """Set up a logger with console and optionally file handlers

     Args:
         logger_name: Name of the logger to set up
         level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
         add_filter: Whether to add LightragPathFilter to the logger
-        log_file_path: Path to the log file. If None, will use current directory/lightrag.log
+        log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
+        enable_file_logging: Whether to enable logging to a file (defaults to True)
     """

     # Configure formatters
     detailed_formatter = logging.Formatter(
@@ -125,18 +127,6 @@ def setup_logger(
     )
     simple_formatter = logging.Formatter("%(levelname)s: %(message)s")

-    # Get log file path
-    if log_file_path is None:
-        log_dir = os.getenv("LOG_DIR", os.getcwd())
-        log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
-
-    # Ensure log directory exists
-    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
-
-    # Get log file max size and backup count from environment variables
-    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
-    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
-
     logger_instance = logging.getLogger(logger_name)
     logger_instance.setLevel(level)
     logger_instance.handlers = []  # Clear existing handlers
@@ -148,16 +138,34 @@ def setup_logger(
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)

-    # Add file handler
-    file_handler = logging.handlers.RotatingFileHandler(
-        filename=log_file_path,
-        maxBytes=log_max_bytes,
-        backupCount=log_backup_count,
-        encoding="utf-8",
-    )
-    file_handler.setFormatter(detailed_formatter)
-    file_handler.setLevel(level)
-    logger_instance.addHandler(file_handler)
+    # Add file handler by default unless explicitly disabled
+    if enable_file_logging:
+        # Get log file path
+        if log_file_path is None:
+            log_dir = os.getenv("LOG_DIR", os.getcwd())
+            log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
+
+        # Ensure log directory exists
+        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+
+        # Get log file max size and backup count from environment variables
+        log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+        log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
+
+        try:
+            # Add file handler
+            file_handler = logging.handlers.RotatingFileHandler(
+                filename=log_file_path,
+                maxBytes=log_max_bytes,
+                backupCount=log_backup_count,
+                encoding="utf-8",
+            )
+            file_handler.setFormatter(detailed_formatter)
+            file_handler.setLevel(level)
+            logger_instance.addHandler(file_handler)
+        except PermissionError as e:
+            logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
+            logger.warning("Continuing with console logging only")

     # Add path filter if requested
     if add_filter:
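
Usage sketch for the extended signature (the import path is assumed to be lightrag.utils):

```python
from lightrag.utils import setup_logger

# Console-only logging: no log file is created or rotated.
setup_logger("lightrag", level="DEBUG", enable_file_logging=False)

# Default behaviour keeps the rotating file handler; LOG_DIR, LOG_MAX_BYTES and
# LOG_BACKUP_COUNT are still read from the environment as shown above.
setup_logger("lightrag", level="INFO", log_file_path="./logs/lightrag.log")
```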