main_merge

This commit is contained in:
Roy
2025-03-08 20:34:29 +00:00
39 changed files with 1475 additions and 244 deletions

View File

@@ -229,3 +229,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Get all records from the collection
# Since ChromaDB doesn't directly support prefix search on IDs,
# we'll get all records and filter in Python
results = self._collection.get(
include=["metadatas", "documents", "embeddings"]
)
matching_records = []
# Filter records where ID starts with the prefix
for i, record_id in enumerate(results["ids"]):
if record_id.startswith(prefix):
matching_records.append(
{
"id": record_id,
"content": results["documents"][i],
"vector": results["embeddings"][i],
**results["metadatas"][i],
}
)
logger.debug(
f"Found {len(matching_records)} records with prefix '{prefix}'"
)
return matching_records
except Exception as e:
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
raise

View File

@@ -371,3 +371,24 @@ class FaissVectorDBStorage(BaseVectorStorage):
return False # Return error
return True # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
matching_records = []
# Search for records with IDs starting with the prefix
for faiss_id, meta in self._id_to_meta.items():
if "__id__" in meta and meta["__id__"].startswith(prefix):
# Create a copy of all metadata and add "id" field
record = {**meta, "id": meta["__id__"]}
matching_records.append(record)
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

View File

@@ -206,3 +206,28 @@ class MilvusVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use Milvus query with expression to find IDs with the given prefix
expression = f'id like "{prefix}%"'
results = self._client.query(
collection_name=self.namespace,
filter=expression,
output_fields=list(self.meta_fields) + ["id"],
)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
return results
except Exception as e:
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
return []

View File

@@ -1045,6 +1045,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
except PyMongoError as e:
logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use MongoDB regex to find documents where _id starts with the prefix
cursor = self._data.find({"_id": {"$regex": f"^{prefix}"}})
matching_records = await cursor.to_list(length=None)
# Format results
results = [{**doc, "id": doc["_id"]} for doc in matching_records]
logger.debug(
f"Found {len(results)} records with prefix '{prefix}' in {self.namespace}"
)
return results
except PyMongoError as e:
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
return []
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()

View File

@@ -236,3 +236,23 @@ class NanoVectorDBStorage(BaseVectorStorage):
return False # Return error
return True # Return success
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
storage = await self.client_storage
matching_records = []
# Search for records with IDs starting with the prefix
for record in storage["data"]:
if "__id__" in record and record["__id__"].startswith(prefix):
matching_records.append({**record, "id": record["__id__"]})
logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

View File

@@ -232,19 +232,26 @@ class NetworkXStorage(BaseGraphStorage):
return sorted(list(labels))
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
Returns:
KnowledgeGraph object containing nodes and edges
@@ -255,6 +262,10 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
@@ -262,11 +273,16 @@ class NetworkXStorage(BaseGraphStorage):
graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
# Find nodes with matching node id based on search_mode
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching
nodes_to_explore.append(n)
node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
@@ -277,26 +293,37 @@ class NetworkXStorage(BaseGraphStorage):
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
@@ -356,7 +383,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
type="DIRECTED",
source=str(source),
target=str(target),
properties=edge_data,

View File

@@ -494,6 +494,41 @@ class OracleVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
raise
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Determine the appropriate table based on namespace
table_name = namespace_to_table_name(self.namespace)
# Create SQL query to find records with IDs starting with prefix
search_sql = f"""
SELECT * FROM {table_name}
WHERE workspace = :workspace
AND id LIKE :prefix_pattern
ORDER BY id
"""
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
# Execute query and get results
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results or []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@dataclass

View File

@@ -585,6 +585,41 @@ class PGVectorStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for prefix search: {self.namespace}")
return []
search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}
try:
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
# Format results to match the expected return format
formatted_results = []
for record in results:
formatted_record = dict(record)
# Ensure id field is available (for consistency with NanoVectorDB implementation)
if "id" not in formatted_record:
formatted_record["id"] = record["id"]
formatted_results.append(formatted_record)
return formatted_results
except Exception as e:
logger.error(f"Error during prefix search for '{prefix}': {e}")
return []
@final
@dataclass
@@ -785,42 +820,85 @@ class PGGraphStorage(BaseGraphStorage):
v = record[k]
# agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
if v.startswith("[") and v.endswith("]"):
if "::vertex" not in v:
continue
v = v.replace("::vertex", "")
vertexes = json.loads(v)
for vertex in vertexes:
vertices[vertex["id"]] = vertex.get("properties")
else:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately
for k in record.keys():
v = record[k]
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
else:
dtype = ""
if v.startswith("[") and v.endswith("]"):
if "::vertex" in v:
v = v.replace("::vertex", "")
vertexes = json.loads(v)
dl = []
for vertex in vertexes:
prop = vertex.get("properties")
if not prop:
prop = {}
prop["label"] = PGGraphStorage._decode_graph_label(
prop["node_id"]
)
dl.append(prop)
d[k] = dl
if dtype == "vertex":
vertex = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
elif "::edge" in v:
v = v.replace("::edge", "")
edges = json.loads(v)
dl = []
for edge in edges:
dl.append(
(
vertices[edge["start_id"]],
edge["label"],
vertices[edge["end_id"]],
)
)
d[k] = dl
else:
print("WARNING: unsupported type")
continue
else:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(
field["node_id"]
)
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else:
d[k] = json.loads(v) if isinstance(v, str) else v
if v is None or (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d
@@ -1294,7 +1372,7 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (
$$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name,
encoded_node_label,
max_depth,
@@ -1303,17 +1381,23 @@ class PGGraphStorage(BaseGraphStorage):
results = await self._query(query)
nodes = set()
nodes = {}
edges = []
unique_edge_ids = set()
for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["m"]:
node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
@@ -1322,16 +1406,36 @@ class PGGraphStorage(BaseGraphStorage):
else:
if result["nodes"]:
for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["relationships"]:
for edge in result["relationships"]:
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge[2]["node_id"])
id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
nodes=[
KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
],
)
return kg

View File

@@ -135,7 +135,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
return [{**dp.payload, "distance": dp.score} for dp in results]
async def index_done_callback(self) -> None:
# Qdrant handles persistence automatically
@@ -233,3 +233,43 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
try:
# Use scroll method to find records with IDs starting with the prefix
results = self._client.scroll(
collection_name=self.namespace,
scroll_filter=models.Filter(
must=[
models.FieldCondition(
key="id", match=models.MatchText(text=prefix, prefix=True)
)
]
),
with_payload=True,
with_vectors=False,
limit=1000, # Adjust as needed for your use case
)
# Extract matching points
matching_records = results[0]
# Format the results to match expected return format
formatted_results = [{**point.payload} for point in matching_records]
logger.debug(
f"Found {len(formatted_results)} records with prefix '{prefix}'"
)
return formatted_results
except Exception as e:
logger.error(f"Error searching for prefix '{prefix}': {e}")
return []

View File

@@ -414,6 +414,55 @@ class TiDBVectorDBStorage(BaseVectorStorage):
# Ti handles persistence automatically
pass
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
"""
else:
logger.warning(
f"Namespace {self.namespace} not supported for prefix search"
)
return []
# Add prefix pattern parameter with % for SQL LIKE
prefix_pattern = f"{prefix}%"
params = {"prefix_pattern": prefix_pattern, "workspace": self.db.workspace}
try:
results = await self.db.query(sql_template, params=params, multirows=True)
logger.debug(
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
)
return results if results else []
except Exception as e:
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []
@final
@dataclass
@@ -968,4 +1017,20 @@ SQL_TEMPLATES = {
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
""",
# Search by prefix SQL templates
"search_entity_by_prefix": """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_relationship_by_prefix": """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id, keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id LIKE :prefix_pattern AND workspace = :workspace
""",
"search_chunk_by_prefix": """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
""",
}