diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 99e4f5c4..d2630659 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -432,19 +432,31 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_entity"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
             "entity_name": item["entity_name"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
-            "chunk_id": item["source_id"],
+            "chunk_ids": chunk_ids,
             # TODO: add document_id
         }
         return upsert_sql, data

     def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+        source_id = item["source_id"]
+        if isinstance(source_id, str) and "<SEP>" in source_id:
+            chunk_ids = source_id.split("<SEP>")
+        else:
+            chunk_ids = [source_id]
+
         data: dict[str, Any] = {
             "workspace": self.db.workspace,
             "id": item["__id__"],
@@ -452,7 +464,7 @@ class PGVectorStorage(BaseVectorStorage):
             "target_id": item["tgt_id"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
-            "chunk_id": item["source_id"],
+            "chunk_ids": chunk_ids,
             # TODO: add document_id
         }
         return upsert_sql, data
@@ -950,10 +962,11 @@ class PGGraphStorage(BaseGraphStorage):
                     vertices.get(edge["end_id"], {}),
                 )
             else:
-                if v is None or (v.count("{") < 1 and v.count("[") < 1):
-                    d[k] = v
-                else:
-                    d[k] = json.loads(v) if isinstance(v, str) else v
+                d[k] = (
+                    json.loads(v)
+                    if isinstance(v, str) and ("{" in v or "[" in v)
+                    else v
+                )

         return d
@@ -1410,29 +1423,22 @@ class PGGraphStorage(BaseGraphStorage):
         """
         MAX_GRAPH_NODES = 1000

+        # Build the query based on whether we want the full graph or a specific subgraph.
         if node_label == "*":
-            query = """SELECT * FROM cypher('%s', $$
-                        MATCH (n:Entity)
-                        OPTIONAL MATCH (n)-[r]->(m:Entity)
-                        RETURN n, r, m
-                        LIMIT %d
-                      $$) AS (n agtype, r agtype, m agtype)""" % (
-                self.graph_name,
-                MAX_GRAPH_NODES,
-            )
+            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                        MATCH (n:Entity)
+                        OPTIONAL MATCH (n)-[r]->(m:Entity)
+                        RETURN n, r, m
+                        LIMIT {MAX_GRAPH_NODES}
+                      $$) AS (n agtype, r agtype, m agtype)"""
         else:
-            encoded_node_label = self._encode_graph_label(node_label.strip('"'))
-            query = """SELECT * FROM cypher('%s', $$
-                        MATCH (n:Entity {node_id: "%s"})
-                        OPTIONAL MATCH p = (n)-[*..%d]-(m)
-                        RETURN nodes(p) AS nodes, relationships(p) AS relationships
-                        LIMIT %d
-                      $$) AS (nodes agtype, relationships agtype)""" % (
-                self.graph_name,
-                encoded_node_label,
-                max_depth,
-                MAX_GRAPH_NODES,
-            )
+            encoded_label = self._encode_graph_label(node_label.strip('"'))
+            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                        MATCH (n:Entity {{node_id: "{encoded_label}"}})
+                        OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
+                        RETURN nodes(p) AS nodes, relationships(p) AS relationships
+                        LIMIT {MAX_GRAPH_NODES}
+                      $$) AS (nodes agtype, relationships agtype)"""

         results = await self._query(query)
@@ -1440,56 +1446,57 @@ class PGGraphStorage(BaseGraphStorage):
         edges = []
         unique_edge_ids = set()

-        for result in results:
-            if node_label == "*":
-                if result["n"]:
-                    node = result["n"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
+        def add_node(node_data: dict):
+            node_id = self._decode_graph_label(node_data["node_id"])
+            if node_id not in nodes:
+                nodes[node_id] = node_data

-                if result["m"]:
-                    node = result["m"]
-                    node_id = self._decode_graph_label(node["node_id"])
-                    if node_id not in nodes:
-                        nodes[node_id] = node
-                if result["r"]:
-                    edge = result["r"]
-                    src_id = self._decode_graph_label(edge["start_id"])
-                    tgt_id = self._decode_graph_label(edge["end_id"])
-                    edges.append((src_id, tgt_id))
-            else:
-                if result["nodes"]:
-                    for node in result["nodes"]:
-                        node_id = self._decode_graph_label(node["node_id"])
-                        if node_id not in nodes:
-                            nodes[node_id] = node
+        def add_edge(edge_data: list):
+            src_id = self._decode_graph_label(edge_data[0]["node_id"])
+            tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
+            edge_key = f"{src_id},{tgt_id}"
+            if edge_key not in unique_edge_ids:
+                unique_edge_ids.add(edge_key)
+                edges.append(
+                    (
+                        edge_key,
+                        src_id,
+                        tgt_id,
+                        {"source": edge_data[0], "target": edge_data[2]},
+                    )
+                )

-                if result["relationships"]:
-                    for edge in result["relationships"]:  # src --DIRECTED--> target
-                        src_id = self._decode_graph_label(edge[0]["node_id"])
-                        tgt_id = self._decode_graph_label(edge[2]["node_id"])
-                        id = src_id + "," + tgt_id
-                        if id in unique_edge_ids:
-                            continue
-                        else:
-                            unique_edge_ids.add(id)
-                            edges.append(
-                                (id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
-                            )
+        # Process the query results.
+        if node_label == "*":
+            for result in results:
+                if result.get("n"):
+                    add_node(result["n"])
+                if result.get("m"):
+                    add_node(result["m"])
+                if result.get("r"):
+                    add_edge(result["r"])
+        else:
+            for result in results:
+                for node in result.get("nodes") or []:
+                    add_node(node)
+                for edge in result.get("relationships") or []:
+                    add_edge(edge)

+        # Construct and return the KnowledgeGraph.
         kg = KnowledgeGraph(
             nodes=[
-                KnowledgeGraphNode(
-                    id=node_id, labels=[node_id], properties=nodes[node_id]
-                )
-                for node_id in nodes
+                KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
+                for node_id, node_data in nodes.items()
             ],
             edges=[
                 KnowledgeGraphEdge(
-                    id=id, type="DIRECTED", source=src, target=tgt, properties=props
+                    id=edge_id,
+                    type="DIRECTED",
+                    source=src,
+                    target=tgt,
+                    properties=props,
                 )
-                for id, src, tgt, props in edges
+                for edge_id, src, tgt, props in edges
             ],
         )
@@ -1654,22 +1661,25 @@ SQL_TEMPLATES = {
                       update_time = CURRENT_TIMESTAMP
                      """,
    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
-                      content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6)
+                      content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET entity_name=EXCLUDED.entity_name,
                       content=EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
                       update_time=CURRENT_TIMESTAMP
                      """,
    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
-                      target_id, content, content_vector, chunk_id)
-                      VALUES ($1, $2, $3, $4, $5, $6, $7)
+                      target_id, content, content_vector, chunk_ids)
+                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[])
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET source_id=EXCLUDED.source_id,
                       target_id=EXCLUDED.target_id,
                       content=EXCLUDED.content,
-                      content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
+                      content_vector=EXCLUDED.content_vector,
+                      chunk_ids=EXCLUDED.chunk_ids,
+                      update_time = CURRENT_TIMESTAMP
                      """,
    # SQL for VectorStorage
    # "entities": """SELECT entity_name FROM
@@ -1720,8 +1730,8 @@ SQL_TEMPLATES = {
        FROM (
            SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
            FROM LIGHTRAG_VDB_RELATION r
+           JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
            WHERE r.workspace=$1
-           AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
        ) filtered
        WHERE distance>$2
        ORDER BY distance DESC
@@ -1735,11 +1745,11 @@ SQL_TEMPLATES = {
        )
        SELECT entity_name FROM
            (
-               SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-               FROM LIGHTRAG_VDB_ENTITY
-               where workspace=$1
-               AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
-           ) as chunk_distances
+               SELECT e.id, e.entity_name, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
+               FROM LIGHTRAG_VDB_ENTITY e
+               JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
+               WHERE e.workspace=$1
+           ) filtered
        WHERE distance>$2
        ORDER BY distance DESC
        LIMIT $3
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 555fea90..70aa0ceb 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -123,21 +123,18 @@ async def openai_complete_if_cache(

         async def inner():
             try:
-                _content = ""
                 async for chunk in response:
                     content = chunk.choices[0].delta.content
                     if content is None:
                         continue
                     if r"\u" in content:
                         content = safe_unicode_decode(content.encode("utf-8"))
-                    _content += content
-                return _content
+                    yield content
             except Exception as e:
                 logger.error(f"Error in stream response: {str(e)}")
                 raise

-        response_content = await inner()
-        return response_content
+        return inner()

     else:
         if (
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 1815f308..d062ae73 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -404,7 +404,7 @@ async def extract_entities(
             language=language,
         )

-    continue_prompt = PROMPTS["entity_continue_extraction"]
+    continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
     if_loop_prompt = PROMPTS["entity_if_loop_extraction"]

     processed_chunks = 0