diff --git a/README.md b/README.md
index 74c40f15..6d5af135 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
-
+
@@ -16,27 +16,34 @@
+
+
+
+
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).

## 🎉 News
-- [x] [2024.11.12]🎯📢You can [use Oracle Database 23ai for all storage types (kv/vector/graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now.
+- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author!
+- [x] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
- [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
- [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
- [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
- [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
-- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
+- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/yF2MmDJyGJ)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
## Algorithm Flowchart
-
-
+
+*Figure 1: LightRAG Indexing Flowchart*
+
+*Figure 2: LightRAG Retrieval and Querying Flowchart*
## Install
diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py
index 3b2cafc6..c06b8a83 100644
--- a/examples/lightrag_api_oracle_demo..py
+++ b/examples/lightrag_api_oracle_demo..py
@@ -162,12 +162,12 @@ class Response(BaseModel):
# API routes
-rag = None # 定义为全局对象
+rag = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
- rag = await init() # 在应用启动时初始化 `rag`
+ rag = await init()
print("done!")
yield
diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
index e29a6a9d..6a3fafc1 100644
--- a/examples/lightrag_azure_openai_demo.py
+++ b/examples/lightrag_azure_openai_demo.py
@@ -4,8 +4,8 @@ from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
-import aiohttp
import logging
+from openai import AzureOpenAI
logging.basicConfig(level=logging.INFO)
@@ -32,11 +32,11 @@ os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
- headers = {
- "Content-Type": "application/json",
- "api-key": AZURE_OPENAI_API_KEY,
- }
- endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
+ client = AzureOpenAI(
+ api_key=AZURE_OPENAI_API_KEY,
+ api_version=AZURE_OPENAI_API_VERSION,
+ azure_endpoint=AZURE_OPENAI_ENDPOINT,
+ )
messages = []
if system_prompt:
@@ -45,41 +45,26 @@ async def llm_model_func(
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
- payload = {
- "messages": messages,
- "temperature": kwargs.get("temperature", 0),
- "top_p": kwargs.get("top_p", 1),
- "n": kwargs.get("n", 1),
- }
-
- async with aiohttp.ClientSession() as session:
- async with session.post(endpoint, headers=headers, json=payload) as response:
- if response.status != 200:
- raise ValueError(
- f"Request failed with status {response.status}: {await response.text()}"
- )
- result = await response.json()
- return result["choices"][0]["message"]["content"]
+ chat_completion = client.chat.completions.create(
+ model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name".
+ messages=messages,
+ temperature=kwargs.get("temperature", 0),
+ top_p=kwargs.get("top_p", 1),
+ n=kwargs.get("n", 1),
+ )
+ return chat_completion.choices[0].message.content
async def embedding_func(texts: list[str]) -> np.ndarray:
- headers = {
- "Content-Type": "application/json",
- "api-key": AZURE_OPENAI_API_KEY,
- }
- endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
+ client = AzureOpenAI(
+ api_key=AZURE_OPENAI_API_KEY,
+ api_version=AZURE_EMBEDDING_API_VERSION,
+ azure_endpoint=AZURE_OPENAI_ENDPOINT,
+ )
+ embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
- payload = {"input": texts}
-
- async with aiohttp.ClientSession() as session:
- async with session.post(endpoint, headers=headers, json=payload) as response:
- if response.status != 200:
- raise ValueError(
- f"Request failed with status {response.status}: {await response.text()}"
- )
- result = await response.json()
- embeddings = [item["embedding"] for item in result["data"]]
- return np.array(embeddings)
+ embeddings = [item.embedding for item in embedding.data]
+ return np.array(embeddings)
async def test_funcs():
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index 6d9003ff..c8b61765 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "1.0.0"
+__version__ = "1.0.1"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index e6b33a9b..84efae81 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -86,9 +86,6 @@ class Neo4JStorage(BaseGraphStorage):
)
return single_result["edgeExists"]
- def close(self):
- self._driver.close()
-
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
entity_name_label = node_id.strip('"')
@@ -214,6 +211,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
+ neo4jExceptions.ClientError,
)
),
)
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 2e394b8a..2cfbd249 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,7 +114,7 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
- async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
+ async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
@@ -173,10 +173,11 @@ class OracleKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
- SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
- params = {"workspace":self.db.workspace, "id":id}
+ SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+ workspace=self.db.workspace, id=id
+ )
# print("get_by_id:"+SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL)
if res:
data = res # {"data":res}
# print (data)
@@ -187,11 +188,11 @@ class OracleKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
- params = {"workspace":self.db.workspace}
- #print("get_by_ids:"+SQL)
- #print(params)
- res = await self.db.query(SQL,params, multirows=True)
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+ workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+ )
+ # print("get_by_ids:"+SQL)
+ res = await self.db.query(SQL, multirows=True)
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -201,16 +202,12 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
- SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
- ids=",".join([f"'{id}'" for id in keys]))
- params = {"workspace":self.db.workspace}
- try:
- await self.db.query(SQL, params)
- except Exception as e:
- logger.error(f"Oracle database error: {e}")
- print(SQL)
- print(params)
- res = await self.db.query(SQL, params,multirows=True)
+ SQL = SQL_TEMPLATES["filter_keys"].format(
+ table_name=N_T[self.namespace],
+ workspace=self.db.workspace,
+ ids=",".join([f"'{k}'" for k in keys]),
+ )
+ res = await self.db.query(SQL, multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
@@ -247,29 +244,27 @@ class OracleKVStorage(BaseKVStorage):
d["__vector__"] = embeddings[i]
# print(list_data)
for item in list_data:
- merge_sql = SQL_TEMPLATES["merge_chunk"]
- data = {"check_id":item["__id__"],
- "id":item["__id__"],
- "content":item["content"],
- "workspace":self.db.workspace,
- "tokens":item["tokens"],
- "chunk_order_index":item["chunk_order_index"],
- "full_doc_id":item["full_doc_id"],
- "content_vector":item["__vector__"]
- }
+ merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
+
+ values = [
+ item["__id__"],
+ item["content"],
+ self.db.workspace,
+ item["tokens"],
+ item["chunk_order_index"],
+ item["full_doc_id"],
+ item["__vector__"],
+ ]
# print(merge_sql)
await self.db.execute(merge_sql, data)
if self.namespace == "full_docs":
for k, v in self._data.items():
# values.clear()
- merge_sql = SQL_TEMPLATES["merge_doc_full"]
- data = {
- "check_id":k,
- "id":k,
- "content":v["content"],
- "workspace":self.db.workspace
- }
+ merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
+ check_id=k,
+ )
+ values = [k, self._data[k]["content"], self.db.workspace]
# print(merge_sql)
await self.db.execute(merge_sql, data)
return left_data
@@ -301,17 +296,18 @@ class OracleVectorDBStorage(BaseVectorStorage):
# 转换精度
dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0]
- embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
+ embedding_string = ", ".join(map(str, embedding.tolist()))
- SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
- params = {
- "embedding_string": embedding_string,
- "workspace": self.db.workspace,
- "top_k": top_k,
- "better_than_threshold": self.cosine_better_than_threshold,
- }
+ SQL = SQL_TEMPLATES[self.namespace].format(
+ embedding_string=embedding_string,
+ dimension=dimension,
+ dtype=dtype,
+ workspace=self.db.workspace,
+ top_k=top_k,
+ better_than_threshold=self.cosine_better_than_threshold,
+ )
# print(SQL)
- results = await self.db.query(SQL,params=params, multirows=True)
+ results = await self.db.query(SQL, multirows=True)
# print("vector search result:",results)
return results
@@ -346,18 +342,22 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_node"]
- data = {
- "workspace":self.db.workspace,
- "name":entity_name,
- "entity_type":entity_type,
- "description":description,
- "source_chunk_id":source_id,
- "content":content,
- "content_vector":content_vector
- }
+ merge_sql = SQL_TEMPLATES["merge_node"].format(
+ workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
+ )
# print(merge_sql)
- await self.db.execute(merge_sql,data)
+ await self.db.execute(
+ merge_sql,
+ [
+ self.db.workspace,
+ entity_name,
+ entity_type,
+ description,
+ source_id,
+ content,
+ content_vector,
+ ],
+ )
# self._graph.add_node(node_id, **node_data)
async def upsert_edge(
@@ -371,8 +371,6 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
- logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
-
content = keywords + source_name + target_name + description
contents = [content]
batches = [
@@ -384,20 +382,27 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
- merge_sql = SQL_TEMPLATES["merge_edge"]
- data = {
- "workspace":self.db.workspace,
- "source_name":source_name,
- "target_name":target_name,
- "weight":weight,
- "keywords":keywords,
- "description":description,
- "source_chunk_id":source_chunk_id,
- "content":content,
- "content_vector":content_vector
- }
+ merge_sql = SQL_TEMPLATES["merge_edge"].format(
+ workspace=self.db.workspace,
+ source_name=source_name,
+ target_name=target_name,
+ source_chunk_id=source_chunk_id,
+ )
# print(merge_sql)
- await self.db.execute(merge_sql,data)
+ await self.db.execute(
+ merge_sql,
+ [
+ self.db.workspace,
+ source_name,
+ target_name,
+ weight,
+ keywords,
+ description,
+ source_chunk_id,
+ content,
+ content_vector,
+ ],
+ )
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -427,14 +432,12 @@ class OracleGraphStorage(BaseGraphStorage):
#################### query method #################
async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
- SQL = SQL_TEMPLATES["has_node"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ SQL = SQL_TEMPLATES["has_node"].format(
+ workspace=self.db.workspace, node_id=node_id
+ )
# print(SQL)
# print(self.db.workspace, node_id)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL)
if res:
# print("Node exist!",res)
return True
@@ -444,14 +447,13 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
- SQL = SQL_TEMPLATES["has_edge"]
- params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id,
- "target_node_id":target_node_id
- }
+ SQL = SQL_TEMPLATES["has_edge"].format(
+ workspace=self.db.workspace,
+ source_node_id=source_node_id,
+ target_node_id=target_node_id,
+ )
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL)
if res:
# print("Edge exist!",res)
return True
@@ -461,13 +463,11 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
- SQL = SQL_TEMPLATES["node_degree"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ SQL = SQL_TEMPLATES["node_degree"].format(
+ workspace=self.db.workspace, node_id=node_id
+ )
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL)
if res:
# print("Node degree",res["degree"])
return res["degree"]
@@ -483,14 +483,12 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据"""
- SQL = SQL_TEMPLATES["get_node"]
- params = {
- "workspace":self.db.workspace,
- "node_id":node_id
- }
+ SQL = SQL_TEMPLATES["get_node"].format(
+ workspace=self.db.workspace, node_id=node_id
+ )
# print(self.db.workspace, node_id)
# print(SQL)
- res = await self.db.query(SQL,params)
+ res = await self.db.query(SQL)
if res:
# print("Get node!",self.db.workspace, node_id,res)
return res
@@ -502,13 +500,12 @@ class OracleGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""根据源和目标节点id获取边"""
- SQL = SQL_TEMPLATES["get_edge"]
- params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id,
- "target_node_id":target_node_id
- }
- res = await self.db.query(SQL,params)
+ SQL = SQL_TEMPLATES["get_edge"].format(
+ workspace=self.db.workspace,
+ source_node_id=source_node_id,
+ target_node_id=target_node_id,
+ )
+ res = await self.db.query(SQL)
if res:
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
return res
@@ -519,12 +516,10 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node_edges(self, source_node_id: str):
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id):
- SQL = SQL_TEMPLATES["get_node_edges"]
- params = {
- "workspace":self.db.workspace,
- "source_node_id":source_node_id
- }
- res = await self.db.query(sql=SQL, params=params, multirows=True)
+ SQL = SQL_TEMPLATES["get_node_edges"].format(
+ workspace=self.db.workspace, source_node_id=source_node_id
+ )
+ res = await self.db.query(sql=SQL, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
# print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -532,29 +527,7 @@ class OracleGraphStorage(BaseGraphStorage):
else:
# print("Node Edge not exist!",self.db.workspace, source_node_id)
return []
-
- async def get_all_nodes(self, limit: int):
- """查询所有节点"""
- SQL = SQL_TEMPLATES["get_all_nodes"]
- params = {"workspace":self.db.workspace, "limit":str(limit)}
- res = await self.db.query(sql=SQL,params=params, multirows=True)
- if res:
- return res
- async def get_all_edges(self, limit: int):
- """查询所有边"""
- SQL = SQL_TEMPLATES["get_all_edges"]
- params = {"workspace":self.db.workspace, "limit":str(limit)}
- res = await self.db.query(sql=SQL,params=params, multirows=True)
- if res:
- return res
-
- async def get_statistics(self):
- SQL = SQL_TEMPLATES["get_statistics"]
- params = {"workspace":self.db.workspace}
- res = await self.db.query(sql=SQL,params=params, multirows=True)
- if res:
- return res
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
@@ -726,37 +699,5 @@ SQL_TEMPLATES = {
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
- values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
- "get_all_nodes":"""WITH t0 AS (
- SELECT name AS id, entity_type AS label, entity_type, description,
- '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids
- FROM lightrag_graph_nodes
- WHERE workspace = :workspace
- ORDER BY createtime DESC fetch first :limit rows only
- ), t1 AS (
- SELECT t0.id, source_chunk_id
- FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
- ), t2 AS (
- SELECT t1.id, LISTAGG(t2.content, '\n') content
- FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
- GROUP BY t1.id
- )
- SELECT t0.id, label, entity_type, description, t2.content
- FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
- "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
- t1.weight,t1.DESCRIPTION,t2.content
- FROM LIGHTRAG_GRAPH_EDGES t1
- LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
- WHERE t1.workspace=:workspace
- order by t1.CREATETIME DESC
- fetch first :limit rows only""",
- "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
- count(distinct CASE WHEN type='edge' THEN id END) as edges_count
- FROM (
- select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
- MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
- UNION
- select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
- MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
- )""",
+ values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
}
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 2687877a..2d45dbce 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -172,9 +172,7 @@ class LightRAG:
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
- namespace="chunk_entity_relation",
- global_config=asdict(self),
- embedding_func=self.embedding_func,
+ namespace="chunk_entity_relation", global_config=asdict(self)
)
####
# add embedding func by walter over
@@ -226,6 +224,7 @@ class LightRAG:
return loop.run_until_complete(self.ainsert(string_or_strings))
async def ainsert(self, string_or_strings):
+ update_storage = False
try:
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
@@ -239,6 +238,7 @@ class LightRAG:
if not len(new_docs):
logger.warning("All docs are already in the storage")
return
+ update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
inserting_chunks = {}
@@ -285,7 +285,8 @@ class LightRAG:
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
- await self._insert_done()
+ if update_storage:
+ await self._insert_done()
async def _insert_done(self):
tasks = []
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 6263f153..1acf07e0 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -696,13 +696,17 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+ device = next(embed_model.parameters()).device
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
- ).input_ids
+ ).input_ids.to(device)
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
- return embeddings.detach().numpy()
+ if embeddings.dtype == torch.bfloat16:
+ return embeddings.detach().to(torch.float32).cpu().numpy()
+ else:
+ return embeddings.detach().cpu().numpy()
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 12f78dcd..1071f8c2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
all_text_units_lookup = {}
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
for c_id in this_text_units:
- if c_id in all_text_units_lookup:
- continue
- relation_counts = 0
- if this_edges: # Add check for None edges
+ if c_id not in all_text_units_lookup:
+ all_text_units_lookup[c_id] = {
+ "data": await text_chunks_db.get_by_id(c_id),
+ "order": index,
+ "relation_counts": 0,
+ }
+
+ if this_edges:
for e in this_edges:
if (
e[1] in all_one_hop_text_units_lookup
and c_id in all_one_hop_text_units_lookup[e[1]]
):
- relation_counts += 1
-
- chunk_data = await text_chunks_db.get_by_id(c_id)
- if chunk_data is not None and "content" in chunk_data: # Add content check
- all_text_units_lookup[c_id] = {
- "data": chunk_data,
- "order": index,
- "relation_counts": relation_counts,
- }
+ all_text_units_lookup[c_id]["relation_counts"] += 1
# Filter out None values and ensure data has content
all_text_units = [
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
all_related_edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
- all_edges = set()
+ all_edges = []
+ seen = set()
+
for this_edges in all_related_edges:
- all_edges.update([tuple(sorted(e)) for e in this_edges])
- all_edges = list(all_edges)
+ for e in this_edges:
+ sorted_edge = tuple(sorted(e))
+ if sorted_edge not in seen:
+ seen.add(sorted_edge)
+ all_edges.append(sorted_edge)
+
all_edges_pack = await asyncio.gather(
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
)
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
- entity_names = set()
+ entity_names = []
+ seen = set()
+
for e in edge_datas:
- entity_names.add(e["src_id"])
- entity_names.add(e["tgt_id"])
+ if e["src_id"] not in seen:
+ entity_names.append(e["src_id"])
+ seen.add(e["src_id"])
+ if e["tgt_id"] not in seen:
+ entity_names.append(e["tgt_id"])
+ seen.add(e["tgt_id"])
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9c0e7577..fc739002 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -290,13 +290,19 @@ def process_combine_contexts(hl, ll):
if list_ll:
list_ll = [",".join(item[1:]) for item in list_ll if item]
- combined_sources_set = set(filter(None, list_hl + list_ll))
+ combined_sources = []
+ seen = set()
- combined_sources = [",\t".join(header)]
+ for item in list_hl + list_ll:
+ if item and item not in seen:
+ combined_sources.append(item)
+ seen.add(item)
- for i, item in enumerate(combined_sources_set, start=1):
- combined_sources.append(f"{i},\t{item}")
+ combined_sources_result = [",\t".join(header)]
- combined_sources = "\n".join(combined_sources)
+ for i, item in enumerate(combined_sources, start=1):
+ combined_sources_result.append(f"{i},\t{item}")
- return combined_sources
+ combined_sources_result = "\n".join(combined_sources_result)
+
+ return combined_sources_result
diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py
index 43c44056..e318c145 100644
--- a/reproduce/Step_1.py
+++ b/reproduce/Step_1.py
@@ -24,7 +24,7 @@ def insert_text(rag, file_path):
cls = "agriculture"
-WORKING_DIR = "../{cls}"
+WORKING_DIR = f"../{cls}"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)