Merge branch 'main' of github.com:lcjqyml/LightRAG
@@ -73,6 +73,8 @@ LLM_BINDING_HOST=http://localhost:11434
 ### Embedding Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
+EMBEDDING_BATCH_NUM=32
+EMBEDDING_FUNC_MAX_ASYNC=16
 # EMBEDDING_BINDING_API_KEY=your_api_key
 ### ollama example
 EMBEDDING_BINDING=ollama

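The two new variables mirror the environment-driven defaults wired into the LightRAG dataclass later in this commit. A minimal sketch of consuming them from Python, assuming the .env file is loaded with python-dotenv (any loader that populates os.environ works; LightRAG's own startup code is not shown here):

    import os
    from dotenv import load_dotenv  # assumed loader

    load_dotenv()  # pulls the EMBEDDING_* variables into os.environ
    batch_num = int(os.getenv("EMBEDDING_BATCH_NUM", 32))
    max_async = int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16))
    print(f"embedding batch size: {batch_num}, max concurrent embedding calls: {max_async}")
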
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.2.6"
+__version__ = "1.2.7"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

@@ -747,8 +747,30 @@ class PGDocStatusStorage(DocStatusStorage):
         )
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """Get doc_chunks data by id"""
-        raise NotImplementedError
+        """Get doc_chunks data by multiple IDs."""
+        if not ids:
+            return []
+
+        sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
+        params = {"workspace": self.db.workspace, "ids": ids}
+
+        results = await self.db.query(sql, params, True)
+
+        if not results:
+            return []
+        return [
+            {
+                "content": row["content"],
+                "content_length": row["content_length"],
+                "content_summary": row["content_summary"],
+                "status": row["status"],
+                "chunks_count": row["chunks_count"],
+                "created_at": row["created_at"],
+                "updated_at": row["updated_at"],
+                "file_path": row["file_path"],
+            }
+            for row in results
+        ]
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""

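The new implementation relies on PostgreSQL's id = ANY($2) form, so a whole list of ids can be bound to a single placeholder instead of issuing one query per id. A standalone sketch of the same pattern, assuming the asyncpg driver and an existing LIGHTRAG_DOC_STATUS table (LightRAG's own db.query wrapper is not shown in this hunk):

    import asyncio
    import asyncpg  # assumed driver for the $1/$2-style placeholders above

    async def fetch_doc_statuses(dsn: str, workspace: str, ids: list[str]) -> list[dict]:
        conn = await asyncpg.connect(dsn)
        try:
            # A Python list binds directly to ANY($2).
            rows = await conn.fetch(
                "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)",
                workspace,
                ids,
            )
            return [dict(r) for r in rows]
        finally:
            await conn.close()

    # asyncio.run(fetch_doc_statuses("postgresql://user:pass@localhost/lightrag", "default", ["doc-1", "doc-2"]))
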
@@ -1570,7 +1592,7 @@ TABLES = {
                    content_vector VECTOR,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
-                   chunk_id TEXT NULL,
+                   chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""

@@ -1585,7 +1607,7 @@ TABLES = {
                    content_vector VECTOR,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP,
-                   chunk_id TEXT NULL,
+                   chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""

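These CREATE TABLE statements only take effect on fresh databases. A database created with the earlier schema would still have the old chunk_id column and no chunk_ids array; a purely illustrative migration sketch for that situation (not part of this commit; table and column names are taken from the hunks above, asyncpg usage is assumed):

    # Hypothetical one-off migration for databases created before this change.
    MIGRATION_STATEMENTS = [
        "ALTER TABLE LIGHTRAG_VDB_ENTITY DROP COLUMN IF EXISTS chunk_id",
        "ALTER TABLE LIGHTRAG_VDB_ENTITY ADD COLUMN IF NOT EXISTS chunk_ids VARCHAR(255)[] NULL",
        "ALTER TABLE LIGHTRAG_VDB_RELATION DROP COLUMN IF EXISTS chunk_id",
        "ALTER TABLE LIGHTRAG_VDB_RELATION ADD COLUMN IF NOT EXISTS chunk_ids VARCHAR(255)[] NULL",
    ]

    async def migrate(conn):  # conn: an asyncpg.Connection
        for stmt in MIGRATION_STATEMENTS:
            await conn.execute(stmt)
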
@@ -1673,7 +1695,7 @@ SQL_TEMPLATES = {
     """,
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
                       content_vector, chunk_ids, file_path)
-                      VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7::varchar[])
+                      VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET entity_name=EXCLUDED.entity_name,
                           content=EXCLUDED.content,

@@ -1684,7 +1706,7 @@ SQL_TEMPLATES = {
     """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
                       target_id, content, content_vector, chunk_ids, file_path)
-                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8::varchar[])
+                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8)
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET source_id=EXCLUDED.source_id,
                           target_id=EXCLUDED.target_id,

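In both templates the only change is the last placeholder: chunk_ids keeps its ::varchar[] cast and is bound to a Python list, while file_path is now bound as a single TEXT value rather than an array. An illustrative parameter tuple for the upsert_entity template above (all values are made up; the surrounding upsert helper is not shown in this diff):

    params = (
        "default",                   # $1 workspace
        "ent-0123abcd",              # $2 id
        "Alan Turing",               # $3 entity_name
        "Alan Turing\nBritish mathematician.",  # $4 content
        [0.12, 0.34, 0.56],          # $5 content_vector
        ["chunk-abc", "chunk-def"],  # $6 chunk_ids -> stored as varchar[]
        "papers/turing1950.pdf",     # $7 file_path -> plain TEXT (was varchar[] before this commit)
    )
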
@@ -183,10 +183,10 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""
 
-    embedding_batch_num: int = field(default=32)
+    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
     """Batch size for embedding computations."""
 
-    embedding_func_max_async: int = field(default=16)
+    embedding_func_max_async: int = field(default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)))
     """Maximum number of concurrent embedding function calls."""
 
     embedding_cache_config: dict[str, Any] = field(

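One consequence of field(default=int(os.getenv(...))) is that the environment is read once, when the module defining the dataclass is imported; changing the variable later in the same process has no effect on new instances. If per-instance evaluation were wanted, default_factory would provide it. A sketch of that alternative (this is not what the commit does):

    import os
    from dataclasses import dataclass, field

    @dataclass
    class EmbeddingSettings:
        # Re-reads the environment every time an instance is created.
        batch_num: int = field(default_factory=lambda: int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
        max_async: int = field(default_factory=lambda: int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)))
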
@@ -1947,6 +1947,8 @@ class LightRAG:
 
             # 2. Update entity information in the graph
             new_node_data = {**node_data, **updated_data}
+            new_node_data["entity_id"] = new_entity_name
+
             if "entity_name" in new_node_data:
                 del new_node_data[
                     "entity_name"

@@ -1963,7 +1965,7 @@ class LightRAG:
 
             # Store relationships that need to be updated
             relations_to_update = []
-
+            relations_to_delete = []
             # Get all edges related to the original entity
             edges = await self.chunk_entity_relation_graph.get_node_edges(
                 entity_name

@@ -1975,6 +1977,12 @@ class LightRAG:
                     source, target
                 )
                 if edge_data:
+                    relations_to_delete.append(
+                        compute_mdhash_id(source + target, prefix="rel-")
+                    )
+                    relations_to_delete.append(
+                        compute_mdhash_id(target + source, prefix="rel-")
+                    )
                     if source == entity_name:
                         await self.chunk_entity_relation_graph.upsert_edge(
                             new_entity_name, target, edge_data

@@ -2000,6 +2008,12 @@ class LightRAG:
                 f"Deleted old entity '{entity_name}' and its vector embedding from database"
             )
 
+            # Delete old relation records from vector database
+            await self.relationships_vdb.delete(relations_to_delete)
+            logger.info(
+                f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
+            )
+
             # Update relationship vector representations
             for src, tgt, edge_data in relations_to_update:
                 description = edge_data.get("description", "")

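Both orderings of every touched edge are hashed into "rel-" ids before deletion, because the relationship may have been stored in the vector database under either (source, target) orientation. compute_mdhash_id itself is not shown in this diff; a sketch of the kind of helper the calls imply, assuming an MD5-based content hash plus prefix:

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # Stable id: prefix + MD5 of the content string (assumed shape of LightRAG's helper).
        return prefix + md5(content.encode()).hexdigest()

    # Ids for both orientations of one edge, mirroring the deletion logic above.
    source, target = "Alan Turing", "Enigma"
    ids_to_delete = [
        compute_mdhash_id(source + target, prefix="rel-"),
        compute_mdhash_id(target + source, prefix="rel-"),
    ]
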
@@ -2498,39 +2512,21 @@ class LightRAG:
         # 4. Get all relationships of the source entities
         all_relations = []
         for entity_name in source_entities:
-            # Get all relationships where this entity is the source
-            outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges(
+            # Get all relationships of the source entities
+            edges = await self.chunk_entity_relation_graph.get_node_edges(
                 entity_name
             )
-            if outgoing_edges:
-                for src, tgt in outgoing_edges:
+            if edges:
+                for src, tgt in edges:
                     # Ensure src is the current entity
                     if src == entity_name:
                         edge_data = await self.chunk_entity_relation_graph.get_edge(
                             src, tgt
                         )
-                        all_relations.append(("outgoing", src, tgt, edge_data))
-
-            # Get all relationships where this entity is the target
-            incoming_edges = []
-            all_labels = await self.chunk_entity_relation_graph.get_all_labels()
-            for label in all_labels:
-                if label == entity_name:
-                    continue
-                node_edges = await self.chunk_entity_relation_graph.get_node_edges(
-                    label
-                )
-                for src, tgt in node_edges or []:
-                    if tgt == entity_name:
-                        incoming_edges.append((src, tgt))
-
-            for src, tgt in incoming_edges:
-                edge_data = await self.chunk_entity_relation_graph.get_edge(
-                    src, tgt
-                )
-                all_relations.append(("incoming", src, tgt, edge_data))
+                        all_relations.append((src, tgt, edge_data))
 
         # 5. Create or update the target entity
+        merged_entity_data["entity_id"] = target_entity
         if not target_exists:
             await self.chunk_entity_relation_graph.upsert_node(
                 target_entity, merged_entity_data

@@ -2544,8 +2540,11 @@ class LightRAG:
 
         # 6. Recreate all relationships, pointing to the target entity
         relation_updates = {}  # Track relationships that need to be merged
+        relations_to_delete = []
 
-        for rel_type, src, tgt, edge_data in all_relations:
+        for src, tgt, edge_data in all_relations:
+            relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-"))
+            relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-"))
             new_src = target_entity if src in source_entities else src
             new_tgt = target_entity if tgt in source_entities else tgt
 
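When several source entities are merged, the same relationship can be reached from both of its endpoints, so relations_to_delete may contain duplicate ids. Deleting an id twice is usually harmless for a vector store, but de-duplication is cheap; a small order-preserving sketch (not part of the commit):

    def dedup_ids(ids: list[str]) -> list[str]:
        # dict keys are unique and keep insertion order in Python 3.7+.
        return list(dict.fromkeys(ids))
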
@@ -2590,6 +2589,12 @@ class LightRAG:
                 f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
             )
 
+        # Delete relationships records from vector database
+        await self.relationships_vdb.delete(relations_to_delete)
+        logger.info(
+            f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
+        )
+
         # 7. Update entity vector representation
         description = merged_entity_data.get("description", "")
         source_id = merged_entity_data.get("source_id", "")

@@ -2652,19 +2657,6 @@ class LightRAG:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             await self.entities_vdb.delete([entity_id])
 
-            # Also ensure any relationships specific to this entity are deleted from vector DB
-            # This is a safety check, as these should have been transformed to the target entity already
-            entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
-            relations_with_entity = await self.relationships_vdb.search_by_prefix(
-                entity_relation_prefix
-            )
-            if relations_with_entity:
-                relation_ids = [r["id"] for r in relations_with_entity]
-                await self.relationships_vdb.delete(relation_ids)
-                logger.info(
-                    f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
-                )
-
             logger.info(
                 f"Deleted source entity '{entity_name}' and its vector embedding from database"
             )

@@ -138,16 +138,31 @@ async def hf_model_complete(
 
 
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    device = next(embed_model.parameters()).device
+    # Detect the appropriate device
+    if torch.cuda.is_available():
+        device = next(embed_model.parameters()).device  # Use CUDA if available
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")  # Use MPS for Apple Silicon
+    else:
+        device = torch.device("cpu")  # Fallback to CPU
+
+    # Move the model to the detected device
+    embed_model = embed_model.to(device)
+
+    # Tokenize the input texts and move them to the same device
     encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
     ).to(device)
+
+    # Perform inference
    with torch.no_grad():
         outputs = embed_model(
             input_ids=encoded_texts["input_ids"],
             attention_mask=encoded_texts["attention_mask"],
         )
         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+    # Convert embeddings to NumPy
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
     else:

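A standalone sketch of calling the revised hf_embed, assuming the transformers and torch packages are installed and that the function is importable from lightrag.llm.hf (the model name below is only an example):

    import asyncio
    from transformers import AutoTokenizer, AutoModel
    from lightrag.llm.hf import hf_embed  # import path assumed

    model_name = "BAAI/bge-m3"  # any encoder-style model works for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    embed_model = AutoModel.from_pretrained(model_name)

    # hf_embed is async, so drive it with asyncio.run for a quick check.
    vectors = asyncio.run(hf_embed(["hello world", "LightRAG"], tokenizer, embed_model))
    print(vectors.shape)  # (2, hidden_size)
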
@@ -172,7 +172,7 @@ async def _handle_single_entity_extraction(
         entity_type=entity_type,
         description=entity_description,
         source_id=chunk_key,
-        metadata={"created_at": time.time(), "file_path": file_path},
+        file_path=file_path,
     )
 
 
@@ -201,7 +201,7 @@ async def _handle_single_relationship_extraction(
         description=edge_description,
         keywords=edge_keywords,
         source_id=edge_source_id,
-        metadata={"created_at": time.time(), "file_path": file_path},
+        file_path=file_path,
     )
 
 
@@ -224,9 +224,7 @@ async def _merge_nodes_then_upsert(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
         )
         already_file_paths.extend(
-            split_string_by_multi_markers(
-                already_node["metadata"]["file_path"], [GRAPH_FIELD_SEP]
-            )
+            split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
         )
         already_description.append(already_node["description"])
 
@@ -244,7 +242,7 @@ async def _merge_nodes_then_upsert(
         set([dp["source_id"] for dp in nodes_data] + already_source_ids)
     )
     file_path = GRAPH_FIELD_SEP.join(
-        set([dp["metadata"]["file_path"] for dp in nodes_data] + already_file_paths)
+        set([dp["file_path"] for dp in nodes_data] + already_file_paths)
     )
 
     logger.debug(f"file_path: {file_path}")

@@ -298,7 +296,7 @@ async def _merge_edges_then_upsert(
         if already_edge.get("file_path") is not None:
             already_file_paths.extend(
                 split_string_by_multi_markers(
-                    already_edge["metadata"]["file_path"], [GRAPH_FIELD_SEP]
+                    already_edge["file_path"], [GRAPH_FIELD_SEP]
                 )
             )
 
@@ -340,11 +338,7 @@ async def _merge_edges_then_upsert(
     )
     file_path = GRAPH_FIELD_SEP.join(
         set(
-            [
-                dp["metadata"]["file_path"]
-                for dp in edges_data
-                if dp.get("metadata", {}).get("file_path")
-            ]
+            [dp["file_path"] for dp in edges_data if dp.get("file_path")]
             + already_file_paths
         )
     )

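The pattern in these hunks mirrors the existing source_id handling: several file paths are packed into one delimited string with GRAPH_FIELD_SEP when a node or edge is merged, and split again the next time it is updated. A tiny self-contained sketch of that round trip, assuming GRAPH_FIELD_SEP is a separator constant such as "<SEP>" (its real value is defined elsewhere in LightRAG):

    GRAPH_FIELD_SEP = "<SEP>"  # assumed value, for illustration only

    def merge_file_paths(new_paths: list[str], already_joined: str) -> str:
        # Split previously stored paths, union with the new ones, and re-join.
        existing = already_joined.split(GRAPH_FIELD_SEP) if already_joined else []
        return GRAPH_FIELD_SEP.join(set(new_paths + existing))

    merged = merge_file_paths(["a.pdf", "b.pdf"], "b.pdf<SEP>c.pdf")
    # e.g. "a.pdf<SEP>b.pdf<SEP>c.pdf" (set order is not guaranteed)
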
@@ -679,10 +673,6 @@ async def extract_entities(
                 "content": f"{dp['entity_name']}\n{dp['description']}",
                 "source_id": dp["source_id"],
                 "file_path": dp.get("file_path", "unknown_source"),
-                "metadata": {
-                    "created_at": dp.get("created_at", time.time()),
-                    "file_path": dp.get("file_path", "unknown_source"),
-                },
             }
             for dp in all_entities_data
         }

@@ -697,10 +687,6 @@ async def extract_entities(
                 "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                 "source_id": dp["source_id"],
                 "file_path": dp.get("file_path", "unknown_source"),
-                "metadata": {
-                    "created_at": dp.get("created_at", time.time()),
-                    "file_path": dp.get("file_path", "unknown_source"),
-                },
             }
             for dp in all_relationships_data
         }

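After these two hunks the payload handed to the entity and relationship vector stores is flat: file_path lives at the top level and the nested metadata dict is gone. A partial illustration of one entity record as it would look after the change, showing only the keys visible in this hunk (values are made up):

    entity_record = {
        "content": "Alan Turing\nBritish mathematician and computer scientist.",
        "source_id": "chunk-abc123",
        "file_path": "papers/turing1950.pdf",
        # previously there was also "metadata": {"created_at": ..., "file_path": ...}
    }
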
@@ -1285,11 +1271,8 @@ async def _get_node_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from metadata or directly from node data
+        # Get file path from node data
         file_path = n.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = n.get("metadata", {}).get("file_path", "unknown_source")
 
         entites_section_list.append(
             [

@@ -1323,11 +1306,8 @@ async def _get_node_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from metadata or directly from edge data
+        # Get file path from edge data
         file_path = e.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = e.get("metadata", {}).get("file_path", "unknown_source")
 
         relations_section_list.append(
             [

@@ -1564,11 +1544,8 @@ async def _get_edge_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from metadata or directly from edge data
+        # Get file path from edge data
         file_path = e.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = e.get("metadata", {}).get("file_path", "unknown_source")
 
         relations_section_list.append(
             [

@@ -1594,11 +1571,8 @@ async def _get_edge_data(
         if isinstance(created_at, (int, float)):
             created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
 
-        # Get file path from metadata or directly from node data
+        # Get file path from node data
         file_path = n.get("file_path", "unknown_source")
-        if not file_path or file_path == "unknown_source":
-            # Try to get from metadata
-            file_path = n.get("metadata", {}).get("file_path", "unknown_source")
 
         entites_section_list.append(
             [

@@ -109,15 +109,17 @@ def setup_logger(
     logger_name: str,
     level: str = "INFO",
     add_filter: bool = False,
-    log_file_path: str = None,
+    log_file_path: str | None = None,
+    enable_file_logging: bool = True,
 ):
-    """Set up a logger with console and file handlers
+    """Set up a logger with console and optionally file handlers
 
     Args:
         logger_name: Name of the logger to set up
         level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
         add_filter: Whether to add LightragPathFilter to the logger
-        log_file_path: Path to the log file. If None, will use current directory/lightrag.log
+        log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
+        enable_file_logging: Whether to enable logging to a file (defaults to True)
     """
     # Configure formatters
     detailed_formatter = logging.Formatter(

@@ -125,18 +127,6 @@ def setup_logger(
     )
     simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
 
-    # Get log file path
-    if log_file_path is None:
-        log_dir = os.getenv("LOG_DIR", os.getcwd())
-        log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
-
-    # Ensure log directory exists
-    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
-
-    # Get log file max size and backup count from environment variables
-    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
-    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
-
     logger_instance = logging.getLogger(logger_name)
     logger_instance.setLevel(level)
     logger_instance.handlers = []  # Clear existing handlers

@@ -148,16 +138,34 @@ def setup_logger(
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
 
-    # Add file handler
-    file_handler = logging.handlers.RotatingFileHandler(
-        filename=log_file_path,
-        maxBytes=log_max_bytes,
-        backupCount=log_backup_count,
-        encoding="utf-8",
-    )
-    file_handler.setFormatter(detailed_formatter)
-    file_handler.setLevel(level)
-    logger_instance.addHandler(file_handler)
+    # Add file handler by default unless explicitly disabled
+    if enable_file_logging:
+        # Get log file path
+        if log_file_path is None:
+            log_dir = os.getenv("LOG_DIR", os.getcwd())
+            log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
+
+        # Ensure log directory exists
+        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+
+        # Get log file max size and backup count from environment variables
+        log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+        log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
+
+        try:
+            # Add file handler
+            file_handler = logging.handlers.RotatingFileHandler(
+                filename=log_file_path,
+                maxBytes=log_max_bytes,
+                backupCount=log_backup_count,
+                encoding="utf-8",
+            )
+            file_handler.setFormatter(detailed_formatter)
+            file_handler.setLevel(level)
+            logger_instance.addHandler(file_handler)
+        except PermissionError as e:
+            logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
+            logger.warning("Continuing with console logging only")
 
     # Add path filter if requested
     if add_filter:

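A quick sketch of exercising the new flag, assuming setup_logger is importable from lightrag.utils (the logger name and level are arbitrary):

    import logging
    from lightrag.utils import setup_logger  # import path assumed

    # Console-only logger: no lightrag.log is created and no log directory is touched.
    setup_logger("lightrag", level="DEBUG", enable_file_logging=False)
    logging.getLogger("lightrag").debug("console-only logging is active")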