Merge branch 'main' of github.com:lcjqyml/LightRAG

This commit is contained in:
Milin
2025-03-21 15:22:10 +08:00
7 changed files with 123 additions and 110 deletions

View File

@@ -73,6 +73,8 @@ LLM_BINDING_HOST=http://localhost:11434
### Embedding Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024
EMBEDDING_BATCH_NUM=32
EMBEDDING_FUNC_MAX_ASYNC=16
# EMBEDDING_BINDING_API_KEY=your_api_key
### ollama example
EMBEDDING_BINDING=ollama

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.2.6"
__version__ = "1.2.7"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -747,8 +747,30 @@ class PGDocStatusStorage(DocStatusStorage):
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
raise NotImplementedError
"""Get doc_chunks data by multiple IDs."""
if not ids:
return []
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.db.workspace, "ids": ids}
results = await self.db.query(sql, params, True)
if not results:
return []
return [
{
"content": row["content"],
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"file_path": row["file_path"],
}
for row in results
]
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
@@ -1570,7 +1592,7 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
chunk_id TEXT NULL,
chunk_ids VARCHAR(255)[] NULL,
file_path TEXT NULL,
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
)"""
@@ -1585,7 +1607,7 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
chunk_id TEXT NULL,
chunk_ids VARCHAR(255)[] NULL,
file_path TEXT NULL,
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
)"""
@@ -1673,7 +1695,7 @@ SQL_TEMPLATES = {
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7::varchar[])
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
ON CONFLICT (workspace,id) DO UPDATE
SET entity_name=EXCLUDED.entity_name,
content=EXCLUDED.content,
@@ -1684,7 +1706,7 @@ SQL_TEMPLATES = {
""",
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
target_id, content, content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8::varchar[])
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8)
ON CONFLICT (workspace,id) DO UPDATE
SET source_id=EXCLUDED.source_id,
target_id=EXCLUDED.target_id,

View File

@@ -183,10 +183,10 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = field(default=32)
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
"""Batch size for embedding computations."""
embedding_func_max_async: int = field(default=16)
embedding_func_max_async: int = field(default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)))
"""Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
@@ -1947,6 +1947,8 @@ class LightRAG:
# 2. Update entity information in the graph
new_node_data = {**node_data, **updated_data}
new_node_data["entity_id"] = new_entity_name
if "entity_name" in new_node_data:
del new_node_data[
"entity_name"
@@ -1963,7 +1965,7 @@ class LightRAG:
# Store relationships that need to be updated
relations_to_update = []
relations_to_delete = []
# Get all edges related to the original entity
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
@@ -1975,6 +1977,12 @@ class LightRAG:
source, target
)
if edge_data:
relations_to_delete.append(
compute_mdhash_id(source + target, prefix="rel-")
)
relations_to_delete.append(
compute_mdhash_id(target + source, prefix="rel-")
)
if source == entity_name:
await self.chunk_entity_relation_graph.upsert_edge(
new_entity_name, target, edge_data
@@ -2000,6 +2008,12 @@ class LightRAG:
f"Deleted old entity '{entity_name}' and its vector embedding from database"
)
# Delete old relation records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# Update relationship vector representations
for src, tgt, edge_data in relations_to_update:
description = edge_data.get("description", "")
@@ -2498,39 +2512,21 @@ class LightRAG:
# 4. Get all relationships of the source entities
all_relations = []
for entity_name in source_entities:
# Get all relationships where this entity is the source
outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges(
# Get all relationships of the source entities
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
)
if outgoing_edges:
for src, tgt in outgoing_edges:
if edges:
for src, tgt in edges:
# Ensure src is the current entity
if src == entity_name:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("outgoing", src, tgt, edge_data))
# Get all relationships where this entity is the target
incoming_edges = []
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
for label in all_labels:
if label == entity_name:
continue
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
label
)
for src, tgt in node_edges or []:
if tgt == entity_name:
incoming_edges.append((src, tgt))
for src, tgt in incoming_edges:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("incoming", src, tgt, edge_data))
all_relations.append((src, tgt, edge_data))
# 5. Create or update the target entity
merged_entity_data["entity_id"] = target_entity
if not target_exists:
await self.chunk_entity_relation_graph.upsert_node(
target_entity, merged_entity_data
@@ -2544,8 +2540,11 @@ class LightRAG:
# 6. Recreate all relationships, pointing to the target entity
relation_updates = {} # Track relationships that need to be merged
relations_to_delete = []
for rel_type, src, tgt, edge_data in all_relations:
for src, tgt, edge_data in all_relations:
relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-"))
relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-"))
new_src = target_entity if src in source_entities else src
new_tgt = target_entity if tgt in source_entities else tgt
@@ -2590,6 +2589,12 @@ class LightRAG:
f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
)
# Delete relationships records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# 7. Update entity vector representation
description = merged_entity_data.get("description", "")
source_id = merged_entity_data.get("source_id", "")
@@ -2652,19 +2657,6 @@ class LightRAG:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([entity_id])
# Also ensure any relationships specific to this entity are deleted from vector DB
# This is a safety check, as these should have been transformed to the target entity already
entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
relations_with_entity = await self.relationships_vdb.search_by_prefix(
entity_relation_prefix
)
if relations_with_entity:
relation_ids = [r["id"] for r in relations_with_entity]
await self.relationships_vdb.delete(relation_ids)
logger.info(
f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
)
logger.info(
f"Deleted source entity '{entity_name}' and its vector embedding from database"
)

View File

@@ -138,16 +138,31 @@ async def hf_model_complete(
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
device = next(embed_model.parameters()).device
# Detect the appropriate device
if torch.cuda.is_available():
device = next(embed_model.parameters()).device # Use CUDA if available
elif torch.backends.mps.is_available():
device = torch.device("mps") # Use MPS for Apple Silicon
else:
device = torch.device("cpu") # Fallback to CPU
# Move the model to the detected device
embed_model = embed_model.to(device)
# Tokenize the input texts and move them to the same device
encoded_texts = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).to(device)
# Perform inference
with torch.no_grad():
outputs = embed_model(
input_ids=encoded_texts["input_ids"],
attention_mask=encoded_texts["attention_mask"],
)
embeddings = outputs.last_hidden_state.mean(dim=1)
# Convert embeddings to NumPy
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:

View File

@@ -172,7 +172,7 @@ async def _handle_single_entity_extraction(
entity_type=entity_type,
description=entity_description,
source_id=chunk_key,
metadata={"created_at": time.time(), "file_path": file_path},
file_path=file_path,
)
@@ -201,7 +201,7 @@ async def _handle_single_relationship_extraction(
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
metadata={"created_at": time.time(), "file_path": file_path},
file_path=file_path,
)
@@ -224,9 +224,7 @@ async def _merge_nodes_then_upsert(
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
)
already_file_paths.extend(
split_string_by_multi_markers(
already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP]
)
split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
)
already_description.append(already_node["description"])
@@ -244,7 +242,7 @@ async def _merge_nodes_then_upsert(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
)
file_path = GRAPH_FIELD_SEP.join(
set([dp["metadata"]["file_path"] for dp in nodes_data] + already_file_paths)
set([dp["file_path"] for dp in nodes_data] + already_file_paths)
)
logger.debug(f"file_path: {file_path}")
@@ -298,7 +296,7 @@ async def _merge_edges_then_upsert(
if already_edge.get("file_path") is not None:
already_file_paths.extend(
split_string_by_multi_markers(
already_edge["metadata"]["file_path"], [GRAPH_FIELD_SEP]
already_edge["file_path"], [GRAPH_FIELD_SEP]
)
)
@@ -340,11 +338,7 @@ async def _merge_edges_then_upsert(
)
file_path = GRAPH_FIELD_SEP.join(
set(
[
dp["metadata"]["file_path"]
for dp in edges_data
if dp.get("metadata", {}).get("file_path")
]
[dp["file_path"] for dp in edges_data if dp.get("file_path")]
+ already_file_paths
)
)
@@ -679,10 +673,6 @@ async def extract_entities(
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
"metadata": {
"created_at": dp.get("created_at", time.time()),
"file_path": dp.get("file_path", "unknown_source"),
},
}
for dp in all_entities_data
}
@@ -697,10 +687,6 @@ async def extract_entities(
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
"metadata": {
"created_at": dp.get("created_at", time.time()),
"file_path": dp.get("file_path", "unknown_source"),
},
}
for dp in all_relationships_data
}
@@ -1285,11 +1271,8 @@ async def _get_node_data(
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from metadata or directly from node data
# Get file path from node data
file_path = n.get("file_path", "unknown_source")
if not file_path or file_path == "unknown_source":
# Try to get from metadata
file_path = n.get("metadata", {}).get("file_path", "unknown_source")
entites_section_list.append(
[
@@ -1323,11 +1306,8 @@ async def _get_node_data(
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from metadata or directly from edge data
# Get file path from edge data
file_path = e.get("file_path", "unknown_source")
if not file_path or file_path == "unknown_source":
# Try to get from metadata
file_path = e.get("metadata", {}).get("file_path", "unknown_source")
relations_section_list.append(
[
@@ -1564,11 +1544,8 @@ async def _get_edge_data(
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from metadata or directly from edge data
# Get file path from edge data
file_path = e.get("file_path", "unknown_source")
if not file_path or file_path == "unknown_source":
# Try to get from metadata
file_path = e.get("metadata", {}).get("file_path", "unknown_source")
relations_section_list.append(
[
@@ -1594,11 +1571,8 @@ async def _get_edge_data(
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from metadata or directly from node data
# Get file path from node data
file_path = n.get("file_path", "unknown_source")
if not file_path or file_path == "unknown_source":
# Try to get from metadata
file_path = n.get("metadata", {}).get("file_path", "unknown_source")
entites_section_list.append(
[

View File

@@ -109,15 +109,17 @@ def setup_logger(
logger_name: str,
level: str = "INFO",
add_filter: bool = False,
log_file_path: str = None,
log_file_path: str | None = None,
enable_file_logging: bool = True,
):
"""Set up a logger with console and file handlers
"""Set up a logger with console and optionally file handlers
Args:
logger_name: Name of the logger to set up
level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
add_filter: Whether to add LightragPathFilter to the logger
log_file_path: Path to the log file. If None, will use current directory/lightrag.log
log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
enable_file_logging: Whether to enable logging to a file (defaults to True)
"""
# Configure formatters
detailed_formatter = logging.Formatter(
@@ -125,18 +127,6 @@ def setup_logger(
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
# Get log file path
if log_file_path is None:
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
@@ -148,16 +138,34 @@ def setup_logger(
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add file handler by default unless explicitly disabled
if enable_file_logging:
# Get log file path
if log_file_path is None:
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
try:
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
except PermissionError as e:
logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
logger.warning("Continuing with console logging only")
# Add path filter if requested
if add_filter: