Merge branch 'main' into main

Author: zrguo
Date: 2025-03-17 17:23:41 +08:00
Committed by: GitHub

3 changed files with 92 additions and 85 deletions
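Summary: the Postgres backend moves from a single chunk reference to one-to-many. `_upsert_entities` and `_upsert_relationships` split `<SEP>`-joined `source_id` strings into a `chunk_ids` list, the `chunk_id` varchar column becomes a `chunk_ids` varchar[] array in both vector tables and their upsert/query templates, `PGGraphStorage` gets f-string Cypher queries and deduplicated graph assembly, the OpenAI streaming path yields tokens instead of buffering them, and the entity continue-extraction prompt is now formatted with `context_base`.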


@@ -432,19 +432,31 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_entity"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
             "entity_name": item["entity_name"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
-            "chunk_id": item["source_id"],
+            "chunk_ids": chunk_ids,
             # TODO: add document_id
         }
         return upsert_sql, data

     def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
@@ -452,7 +464,7 @@ class PGVectorStorage(BaseVectorStorage):
"target_id": item["tgt_id"], "target_id": item["tgt_id"],
"content": item["content"], "content": item["content"],
"content_vector": json.dumps(item["__vector__"].tolist()), "content_vector": json.dumps(item["__vector__"].tolist()),
"chunk_id": item["source_id"], "chunk_ids": chunk_ids,
# TODO: add document_id # TODO: add document_id
} }
return upsert_sql, data return upsert_sql, data
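Both upserts normalize `source_id` the same way: a `<SEP>`-joined string fans out into a list, and a single id is wrapped in a one-element list. A minimal standalone sketch of that rule (sample ids invented):

    def to_chunk_ids(source_id):
        # Multi-chunk sources arrive as one "<SEP>"-joined string.
        if isinstance(source_id, str) and "<SEP>" in source_id:
            return source_id.split("<SEP>")
        return [source_id]

    assert to_chunk_ids("chunk-1<SEP>chunk-2") == ["chunk-1", "chunk-2"]
    assert to_chunk_ids("chunk-3") == ["chunk-3"]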
@@ -950,10 +962,11 @@ class PGGraphStorage(BaseGraphStorage):
                         vertices.get(edge["end_id"], {}),
                     )
                 else:
-                    if v is None or (v.count("{") < 1 and v.count("[") < 1):
-                        d[k] = v
-                    else:
-                        d[k] = json.loads(v) if isinstance(v, str) else v
+                    d[k] = (
+                        json.loads(v)
+                        if isinstance(v, str) and ("{" in v or "[" in v)
+                        else v
+                    )
         return d
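The rewritten branch folds the old brace-counting checks into one conditional expression: only strings containing "{" or "[" are JSON-decoded; everything else, including None, passes through untouched. A standalone sketch of the rule:

    import json

    def coerce(v):
        # Decode agtype values that look like JSON containers; pass scalars through.
        return json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v

    assert coerce(None) is None
    assert coerce("plain text") == "plain text"
    assert coerce('{"a": 1}') == {"a": 1}
    assert coerce("[1, 2]") == [1, 2]

Note the heuristic assumes any string containing a brace or bracket is valid JSON; a value like "a {b}" would raise json.JSONDecodeError.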
@@ -1410,29 +1423,22 @@ class PGGraphStorage(BaseGraphStorage):
""" """
MAX_GRAPH_NODES = 1000 MAX_GRAPH_NODES = 1000
# Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*": if node_label == "*":
query = """SELECT * FROM cypher('%s', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity) MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity) OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m RETURN n, r, m
LIMIT %d LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)""" % ( $$) AS (n agtype, r agtype, m agtype)"""
self.graph_name,
MAX_GRAPH_NODES,
)
else: else:
encoded_node_label = self._encode_graph_label(node_label.strip('"')) encoded_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:Entity {{node_id: "{encoded_label}"}})
OPTIONAL MATCH p = (n)-[*..%d]-(m) OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d LIMIT {MAX_GRAPH_NODES}
$$) AS (nodes agtype, relationships agtype)""" % ( $$) AS (nodes agtype, relationships agtype)"""
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
results = await self._query(query) results = await self._query(query)
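Moving from %-interpolation to f-strings means the braces of the Cypher map literal must be doubled so they survive formatting. A two-line illustration:

    label = "abc"
    snippet = f'MATCH (n:Entity {{node_id: "{label}"}})'
    assert snippet == 'MATCH (n:Entity {node_id: "abc"})'

Since the value is spliced straight into the query text, the label must already be sanitized, which is why it goes through `_encode_graph_label` first.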
@@ -1440,56 +1446,57 @@ class PGGraphStorage(BaseGraphStorage):
         edges = []
         unique_edge_ids = set()

-        for result in results:
-            if node_label == "*":
-                if result["n"]:
-                    node = result["n"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
-                if result["m"]:
-                    node = result["m"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
-                if result["r"]:
-                    edge = result["r"]
-                    src_id = self._decode_graph_label(edge["start_id"])
-                    tgt_id = self._decode_graph_label(edge["end_id"])
-                    edges.append((src_id, tgt_id))
-            else:
-                if result["nodes"]:
-                    for node in result["nodes"]:
-                        node_id = self._decode_graph_label(node["node_id"])
-                        if node_id not in nodes:
-                            nodes[node_id] = node
-                if result["relationships"]:
-                    for edge in result["relationships"]:  # src --DIRECTED--> target
-                        src_id = self._decode_graph_label(edge[0]["node_id"])
-                        tgt_id = self._decode_graph_label(edge[2]["node_id"])
-                        id = src_id + "," + tgt_id
-                        if id in unique_edge_ids:
-                            continue
-                        else:
-                            unique_edge_ids.add(id)
-                        edges.append(
-                            (id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
-                        )
+        def add_node(node_data: dict):
+            node_id = self._decode_graph_label(node_data["node_id"])
+            if node_id not in nodes:
+                nodes[node_id] = node_data
+
+        def add_edge(edge_data: list):
+            src_id = self._decode_graph_label(edge_data[0]["node_id"])
+            tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
+            edge_key = f"{src_id},{tgt_id}"
+            if edge_key not in unique_edge_ids:
+                unique_edge_ids.add(edge_key)
+                edges.append(
+                    (
+                        edge_key,
+                        src_id,
+                        tgt_id,
+                        {"source": edge_data[0], "target": edge_data[2]},
+                    )
+                )
+
+        # Process the query results.
+        if node_label == "*":
+            for result in results:
+                if result.get("n"):
+                    add_node(result["n"])
+                if result.get("m"):
+                    add_node(result["m"])
+                if result.get("r"):
+                    add_edge(result["r"])
+        else:
+            for result in results:
+                for node in result.get("nodes", []):
+                    add_node(node)
+                for edge in result.get("relationships", []):
+                    add_edge(edge)

+        # Construct and return the KnowledgeGraph.
         kg = KnowledgeGraph(
             nodes=[
-                KnowledgeGraphNode(
-                    id=node_id, labels=[node_id], properties=nodes[node_id]
-                )
-                for node_id in nodes
+                KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
+                for node_id, node_data in nodes.items()
             ],
             edges=[
                 KnowledgeGraphEdge(
-                    id=id, type="DIRECTED", source=src, target=tgt, properties=props
+                    id=edge_id,
+                    type="DIRECTED",
+                    source=src,
+                    target=tgt,
+                    properties=props,
                 )
-                for id, src, tgt, props in edges
+                for edge_id, src, tgt, props in edges
             ],
         )
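The two helpers replace four near-identical inline blocks, and `add_edge` also repairs the `node_label == "*"` path, which previously appended bare `(src_id, tgt_id)` pairs that the four-element unpack in the edge comprehension could not handle. The dedupe pattern in isolation (ids invented):

    edges, seen = [], set()
    for src, tgt in [("a", "b"), ("a", "b"), ("b", "c")]:
        key = f"{src},{tgt}"
        if key not in seen:
            seen.add(key)
            edges.append((key, src, tgt, {}))
    assert [e[0] for e in edges] == ["a,b", "b,c"]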
@@ -1654,22 +1661,25 @@ SQL_TEMPLATES = {
                      update_time = CURRENT_TIMESTAMP
                     """,
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
-                      content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6)
+                      content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET entity_name=EXCLUDED.entity_name,
                       content=EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
                       update_time=CURRENT_TIMESTAMP
                      """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
-                      target_id, content, content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6, $7)
+                      target_id, content, content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET source_id=EXCLUDED.source_id,
                       target_id=EXCLUDED.target_id,
                       content=EXCLUDED.content,
-                      content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
+                      content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
+                      update_time = CURRENT_TIMESTAMP
                      """,
     # SQL for VectorStorage
     # "entities": """SELECT entity_name FROM
@@ -1720,8 +1730,8 @@ SQL_TEMPLATES = {
     FROM (
         SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
         FROM LIGHTRAG_VDB_RELATION r
+        JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
         WHERE r.workspace=$1
-        AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
     ) filtered
     WHERE distance>$2
     ORDER BY distance DESC
@@ -1735,11 +1745,11 @@ SQL_TEMPLATES = {
     )
     SELECT entity_name FROM
         (
-            SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-            FROM LIGHTRAG_VDB_ENTITY
-            where workspace=$1
-            AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
-        ) as chunk_distances
+            SELECT e.id, e.entity_name, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
+            FROM LIGHTRAG_VDB_ENTITY e
+            JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
+            WHERE e.workspace=$1
+        )
     WHERE distance>$2
     ORDER BY distance DESC
     LIMIT $3
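Both vector queries now pre-filter by joining `relevant_chunks` against the array column with `= ANY(...)` instead of an `IN` subquery on a scalar `chunk_id`. A rough Python analogy of the new predicate (data invented):

    relevant = {"chunk-1", "chunk-3"}
    entities = [
        {"id": "e1", "chunk_ids": ["chunk-1", "chunk-2"]},
        {"id": "e2", "chunk_ids": ["chunk-9"]},
    ]
    hits = [e for e in entities if any(c in relevant for c in e["chunk_ids"])]
    assert [e["id"] for e in hits] == ["e1"]

One reviewer-level caveat: the entity query drops the `as chunk_distances` alias, and PostgreSQL normally rejects an unaliased subquery in FROM, so this template may need the alias restored.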


@@ -123,21 +123,18 @@ async def openai_complete_if_cache(
         async def inner():
             try:
-                _content = ""
                 async for chunk in response:
                     content = chunk.choices[0].delta.content
                     if content is None:
                         continue
                     if r"\u" in content:
                         content = safe_unicode_decode(content.encode("utf-8"))
-                    _content += content
-                return _content
+                    yield content
             except Exception as e:
                 logger.error(f"Error in stream response: {str(e)}")
                 raise
-        response_content = await inner()
-        return response_content
+        return inner()
     else:
         if (
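`inner` is now an async generator, so `return inner()` (no await on the generator itself) hands the caller the generator object, and tokens flow out as they arrive instead of being buffered into `_content`. A minimal self-contained sketch of that shape:

    import asyncio

    async def inner():
        # Async generator: yields pieces as they arrive.
        for piece in ("Hel", "lo"):
            yield piece

    def get_stream():
        # Calling an async generator function returns the generator itself,
        # which the caller can consume with `async for`.
        return inner()

    async def main():
        assert "".join([p async for p in get_stream()]) == "Hello"

    asyncio.run(main())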


@@ -404,7 +404,7 @@ async def extract_entities(
             language=language,
         )
-        continue_prompt = PROMPTS["entity_continue_extraction"]
+        continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
         if_loop_prompt = PROMPTS["entity_if_loop_extraction"]

         processed_chunks = 0
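The old code never formatted the continue-extraction prompt, so its placeholders went out unfilled; it is now rendered with the shared `context_base` mapping, like the first-pass prompt. The pattern, with stand-in keys (the real template and keys live in `PROMPTS` and `context_base`):

    context_base = {"language": "English", "tuple_delimiter": "<|>"}
    template = "Continue extraction in {language}; separate fields with {tuple_delimiter}."
    prompt = template.format(**context_base)
    assert "English" in prompt and "<|>" in prompt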