Merge branch 'main' into main
@@ -432,19 +432,31 @@ class PGVectorStorage(BaseVectorStorage):
 
     def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_entity"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
             "entity_name": item["entity_name"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
-            "chunk_id": item["source_id"],
+            "chunk_ids": chunk_ids,
             # TODO: add document_id
         }
         return upsert_sql, data
 
     def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
@@ -452,7 +464,7 @@ class PGVectorStorage(BaseVectorStorage):
             "target_id": item["tgt_id"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
-            "chunk_id": item["source_id"],
+            "chunk_ids": chunk_ids,
             # TODO: add document_id
         }
         return upsert_sql, data
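
Both upsert helpers now normalise source_id into a list of chunk ids before it is bound to the new chunk_ids array column; a source_id without the "<SEP>" separator simply becomes a one-element list. A minimal standalone sketch of that normalisation (the helper name and the separator constant are illustrative only; the commit inlines the logic in each method):

    from typing import Any

    SEP = "<SEP>"  # separator literal taken from the diff; assumed to match LightRAG's field separator

    def split_chunk_ids(source_id: Any) -> list[str]:
        # Same rule as the inlined code: split on the separator, otherwise wrap in a list.
        if isinstance(source_id, str) and SEP in source_id:
            return source_id.split(SEP)
        return [source_id]

    # "c1<SEP>c2" -> ["c1", "c2"];  "c1" -> ["c1"]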
@@ -950,10 +962,11 @@ class PGGraphStorage(BaseGraphStorage):
                     vertices.get(edge["end_id"], {}),
                 )
             else:
-                if v is None or (v.count("{") < 1 and v.count("[") < 1):
-                    d[k] = v
-                else:
-                    d[k] = json.loads(v) if isinstance(v, str) else v
+                d[k] = (
+                    json.loads(v)
+                    if isinstance(v, str) and ("{" in v or "[" in v)
+                    else v
+                )
 
         return d
 
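
For plain (non-AGE) values the nested branches collapse into one conditional expression: a string that looks like JSON (contains "{" or "[") is parsed with json.loads, anything else, including None, passes through untouched. A small standalone illustration of the resulting rule (not part of the storage class):

    import json

    def coerce(v):
        # Mirror of the new expression in the diff above.
        return json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v

    assert coerce('{"a": 1}') == {"a": 1}
    assert coerce("plain text") == "plain text"
    assert coerce(None) is None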
@@ -1410,29 +1423,22 @@ class PGGraphStorage(BaseGraphStorage):
         """
         MAX_GRAPH_NODES = 1000
 
+        # Build the query based on whether we want the full graph or a specific subgraph.
         if node_label == "*":
-            query = """SELECT * FROM cypher('%s', $$
+            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                          MATCH (n:Entity)
                          OPTIONAL MATCH (n)-[r]->(m:Entity)
                          RETURN n, r, m
-                         LIMIT %d
-                       $$) AS (n agtype, r agtype, m agtype)""" % (
-                self.graph_name,
-                MAX_GRAPH_NODES,
-            )
+                         LIMIT {MAX_GRAPH_NODES}
+                       $$) AS (n agtype, r agtype, m agtype)"""
         else:
-            encoded_node_label = self._encode_graph_label(node_label.strip('"'))
-            query = """SELECT * FROM cypher('%s', $$
-                         MATCH (n:Entity {node_id: "%s"})
-                         OPTIONAL MATCH p = (n)-[*..%d]-(m)
+            encoded_label = self._encode_graph_label(node_label.strip('"'))
+            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                         MATCH (n:Entity {{node_id: "{encoded_label}"}})
+                         OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
                          RETURN nodes(p) AS nodes, relationships(p) AS relationships
-                         LIMIT %d
-                       $$) AS (nodes agtype, relationships agtype)""" % (
-                self.graph_name,
-                encoded_node_label,
-                max_depth,
-                MAX_GRAPH_NODES,
-            )
+                         LIMIT {MAX_GRAPH_NODES}
+                       $$) AS (nodes agtype, relationships agtype)"""
 
         results = await self._query(query)
 
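
Both Cypher queries move from %-formatting to f-strings, which means the literal braces of the Cypher property map must be doubled ({{ }}) so they survive interpolation. A minimal standalone illustration of the escaping (the values are placeholders):

    graph_name = "my_graph"       # placeholder
    encoded_label = "node-123"    # placeholder
    max_depth = 2
    MAX_GRAPH_NODES = 1000

    query = f"""SELECT * FROM cypher('{graph_name}', $$
                 MATCH (n:Entity {{node_id: "{encoded_label}"}})
                 OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
                 RETURN nodes(p) AS nodes, relationships(p) AS relationships
                 LIMIT {MAX_GRAPH_NODES}
               $$) AS (nodes agtype, relationships agtype)"""
    # The doubled {{ }} render as single braces in the emitted Cypher:
    #   MATCH (n:Entity {node_id: "node-123"})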
@@ -1440,56 +1446,57 @@ class PGGraphStorage(BaseGraphStorage):
         edges = []
         unique_edge_ids = set()
 
-        for result in results:
-            if node_label == "*":
-                if result["n"]:
-                    node = result["n"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
-
-                if result["m"]:
-                    node = result["m"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
-                if result["r"]:
-                    edge = result["r"]
-                    src_id = self._decode_graph_label(edge["start_id"])
-                    tgt_id = self._decode_graph_label(edge["end_id"])
-                    edges.append((src_id, tgt_id))
-            else:
-                if result["nodes"]:
-                    for node in result["nodes"]:
-                        node_id = self._decode_graph_label(node["node_id"])
-                        if node_id not in nodes:
-                            nodes[node_id] = node
-
-                if result["relationships"]:
-                    for edge in result["relationships"]:  # src --DIRECTED--> target
-                        src_id = self._decode_graph_label(edge[0]["node_id"])
-                        tgt_id = self._decode_graph_label(edge[2]["node_id"])
-                        id = src_id + "," + tgt_id
-                        if id in unique_edge_ids:
-                            continue
-                        else:
-                            unique_edge_ids.add(id)
-                        edges.append(
-                            (id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
-                        )
+        def add_node(node_data: dict):
+            node_id = self._decode_graph_label(node_data["node_id"])
+            if node_id not in nodes:
+                nodes[node_id] = node_data
+
+        def add_edge(edge_data: list):
+            src_id = self._decode_graph_label(edge_data[0]["node_id"])
+            tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
+            edge_key = f"{src_id},{tgt_id}"
+            if edge_key not in unique_edge_ids:
+                unique_edge_ids.add(edge_key)
+                edges.append(
+                    (
+                        edge_key,
+                        src_id,
+                        tgt_id,
+                        {"source": edge_data[0], "target": edge_data[2]},
+                    )
+                )
+
+        # Process the query results.
+        if node_label == "*":
+            for result in results:
+                if result.get("n"):
+                    add_node(result["n"])
+                if result.get("m"):
+                    add_node(result["m"])
+                if result.get("r"):
+                    add_edge(result["r"])
+        else:
+            for result in results:
+                for node in result.get("nodes", []):
+                    add_node(node)
+                for edge in result.get("relationships", []):
+                    add_edge(edge)
 
+        # Construct and return the KnowledgeGraph.
         kg = KnowledgeGraph(
             nodes=[
-                KnowledgeGraphNode(
-                    id=node_id, labels=[node_id], properties=nodes[node_id]
-                )
-                for node_id in nodes
+                KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
+                for node_id, node_data in nodes.items()
             ],
             edges=[
                 KnowledgeGraphEdge(
-                    id=id, type="DIRECTED", source=src, target=tgt, properties=props
+                    id=edge_id,
+                    type="DIRECTED",
+                    source=src,
+                    target=tgt,
+                    properties=props,
                 )
-                for id, src, tgt, props in edges
+                for edge_id, src, tgt, props in edges
             ],
         )
 
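
The per-result branches are folded into two local helpers: add_node dedupes by decoded node id, and add_edge dedupes by a "src,tgt" key, so both query shapes ("*" and a labelled subgraph) share the same bookkeeping. A tiny standalone illustration of the edge keying (identifiers are placeholders):

    edges: list[tuple] = []
    unique_edge_ids: set[str] = set()

    def add_edge_key(src_id: str, tgt_id: str) -> None:
        # Same dedup rule as add_edge(): at most one edge per (src, tgt) pair.
        edge_key = f"{src_id},{tgt_id}"
        if edge_key not in unique_edge_ids:
            unique_edge_ids.add(edge_key)
            edges.append((edge_key, src_id, tgt_id))

    add_edge_key("a", "b")
    add_edge_key("a", "b")  # ignored, key already seen
    assert len(edges) == 1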
@@ -1654,22 +1661,25 @@ SQL_TEMPLATES = {
     update_time = CURRENT_TIMESTAMP
    """,
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
-                      content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6)
+                      content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET entity_name=EXCLUDED.entity_name,
                       content=EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
                       update_time=CURRENT_TIMESTAMP
                      """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
-                      target_id, content, content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6, $7)
+                      target_id, content, content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET source_id=EXCLUDED.source_id,
                       target_id=EXCLUDED.target_id,
                       content=EXCLUDED.content,
-                      content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
+                      content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
+                      update_time = CURRENT_TIMESTAMP
                      """,
     # SQL for VectorStorage
     # "entities": """SELECT entity_name FROM
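
chunk_ids is now stored as a Postgres varchar[] column, so the bound parameter carries a Python list and the $6::varchar[] / $7::varchar[] cast tells the server how to type the array. A hedged sketch of executing the entity upsert with positional parameters, assuming an asyncpg-style connection and that SQL_TEMPLATES is in scope (the connection setup and how LightRAG maps its data dict onto $1..$6 are not shown in this diff):

    import asyncio
    import asyncpg  # assumed driver for this sketch

    async def upsert_example() -> None:
        conn = await asyncpg.connect("postgresql://user:pass@localhost/db")  # placeholder DSN
        try:
            await conn.execute(
                SQL_TEMPLATES["upsert_entity"],
                "default",               # $1 workspace
                "ent-1",                 # $2 id
                "Example",               # $3 entity_name
                "Example entity",        # $4 content
                "[0.1, 0.2, 0.3]",       # $5 content_vector
                ["chunk-1", "chunk-2"],  # $6: plain Python list binds to ::varchar[]
            )
        finally:
            await conn.close()

    asyncio.run(upsert_example())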
@@ -1720,8 +1730,8 @@ SQL_TEMPLATES = {
         FROM (
             SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
             FROM LIGHTRAG_VDB_RELATION r
+            JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
             WHERE r.workspace=$1
-            AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
         ) filtered
         WHERE distance>$2
         ORDER BY distance DESC
@@ -1735,11 +1745,11 @@ SQL_TEMPLATES = {
         )
         SELECT entity_name FROM
             (
-                SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-                FROM LIGHTRAG_VDB_ENTITY
-                where workspace=$1
-                AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
-            ) as chunk_distances
+                SELECT e.id, e.entity_name, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
+                FROM LIGHTRAG_VDB_ENTITY e
+                JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
+                WHERE e.workspace=$1
+            )
         WHERE distance>$2
         ORDER BY distance DESC
         LIMIT $3
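
Both retrieval queries replace the old scalar filter (chunk_id IN (SELECT ...)) with a join on c.chunk_id = ANY(<table>.chunk_ids): a row now matches when the relevant chunk id appears anywhere in its chunk_ids array. In Python terms the predicate is plain membership (illustration only):

    relevant_chunks = {"chunk-1", "chunk-7"}  # ids selected by the relevant_chunks CTE
    row_chunk_ids = ["chunk-3", "chunk-7"]    # one row's chunk_ids varchar[]

    # c.chunk_id = ANY(row.chunk_ids) keeps the row when any relevant chunk is in the array.
    keep_row = any(chunk in row_chunk_ids for chunk in relevant_chunks)
    assert keep_row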
@@ -123,21 +123,18 @@ async def openai_complete_if_cache(
 
         async def inner():
             try:
-                _content = ""
                 async for chunk in response:
                     content = chunk.choices[0].delta.content
                     if content is None:
                         continue
                     if r"\u" in content:
                         content = safe_unicode_decode(content.encode("utf-8"))
-                    _content += content
-                return _content
+                    yield content
             except Exception as e:
                 logger.error(f"Error in stream response: {str(e)}")
                 raise
 
-        response_content = await inner()
-        return response_content
+        return inner()
 
     else:
         if (
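
With the accumulation removed, inner() becomes an async generator and the function hands it back un-awaited, so callers iterate the stream themselves instead of receiving one concatenated string. A minimal consumption sketch (the model name, prompt, and the stream=True keyword are placeholders for however the caller enables streaming):

    import asyncio

    async def demo() -> None:
        # Assumes streaming was requested; the awaited call then returns the async generator.
        stream = await openai_complete_if_cache("gpt-4o-mini", "Hello", stream=True)  # placeholder call
        async for token in stream:
            print(token, end="", flush=True)

    asyncio.run(demo())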
@@ -404,7 +404,7 @@ async def extract_entities(
         language=language,
     )
 
-    continue_prompt = PROMPTS["entity_continue_extraction"]
+    continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
     if_loop_prompt = PROMPTS["entity_if_loop_extraction"]
 
     processed_chunks = 0
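
The continue-extraction prompt is now pre-filled with the shared context_base mapping up front, the same way the initial extraction prompt is built. A tiny illustration of the .format(**mapping) expansion (the template text and keys are placeholders, not the real prompt):

    template = "Output in {language}; separate records with {record_delimiter}."  # placeholder
    context_base = {"language": "English", "record_delimiter": "##"}              # placeholder keys
    continue_prompt = template.format(**context_base)
    # -> "Output in English; separate records with ##."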