diff --git a/README.md b/README.md
index 74c40f15..6d5af135 100644
--- a/README.md
+++ b/README.md
@@ -16,27 +16,34 @@
 This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).

 ![LightRAG Diagram](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)

 ## 🎉 News
-- [x] [2024.11.12]🎯📢You can [use Oracle Database 23ai for all storage types (kv/vector/graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now.
+- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author!
+- [x] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
 - [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
 - [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
 - [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
 - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
 - [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
 - [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
-- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
+- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/yF2MmDJyGJ)! Join us for sharing and discussion! 🎉🎉
 - [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
 - [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
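Reviewer note: for readers arriving via the news items, the quick start they link to boils down to a few lines. A minimal sketch (the working directory and input file are illustrative; `gpt_4o_mini_complete` is one of the bundled model bindings, and the Ollama/Hugging Face bindings linked above slot into the same parameter):

```python
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete  # or an Ollama / Hugging Face binding

# The working directory holds the KV store, vector index, and graph files.
rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=gpt_4o_mini_complete,
)

with open("./book.txt") as f:
    rag.insert(f.read())

# mode can be "naive", "local", "global", or "hybrid" (dual-level retrieval).
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```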
 ## Algorithm Flowchart

-![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
-
+![LightRAG Indexing Flowchart](https://learnopencv.com/wp-content/uploads/2024/11/LightRAG-VectorDB-Json-KV-Store-Indexing-Flowchart-scaled.jpg)
+*Figure 1: LightRAG Indexing Flowchart*
+![LightRAG Retrieval and Querying Flowchart](https://learnopencv.com/wp-content/uploads/2024/11/LightRAG-Querying-Flowchart-Dual-Level-Retrieval-Generation-Knowledge-Graphs-scaled.jpg)
+*Figure 2: LightRAG Retrieval and Querying Flowchart*

 ## Install

diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py
index 3b2cafc6..c06b8a83 100644
--- a/examples/lightrag_api_oracle_demo..py
+++ b/examples/lightrag_api_oracle_demo..py
@@ -162,12 +162,12 @@ class Response(BaseModel):

 # API routes

-rag = None  # define as a global object
+rag = None


 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global rag
-    rag = await init()  # initialize `rag` at application startup
+    rag = await init()
     print("done!")
     yield

diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
index e29a6a9d..6a3fafc1 100644
--- a/examples/lightrag_azure_openai_demo.py
+++ b/examples/lightrag_azure_openai_demo.py
@@ -4,8 +4,8 @@ from lightrag import LightRAG, QueryParam
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 from dotenv import load_dotenv
-import aiohttp
 import logging
+from openai import AzureOpenAI

 logging.basicConfig(level=logging.INFO)

@@ -32,11 +32,11 @@ os.mkdir(WORKING_DIR)
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    headers = {
-        "Content-Type": "application/json",
-        "api-key": AZURE_OPENAI_API_KEY,
-    }
-    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_OPENAI_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )

     messages = []
     if system_prompt:
@@ -45,41 +45,26 @@ async def llm_model_func(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    payload = {
-        "messages": messages,
-        "temperature": kwargs.get("temperature", 0),
-        "top_p": kwargs.get("top_p", 1),
-        "n": kwargs.get("n", 1),
-    }
-
-    async with aiohttp.ClientSession() as session:
-        async with session.post(endpoint, headers=headers, json=payload) as response:
-            if response.status != 200:
-                raise ValueError(
-                    f"Request failed with status {response.status}: {await response.text()}"
-                )
-            result = await response.json()
-            return result["choices"][0]["message"]["content"]
+    chat_completion = client.chat.completions.create(
+        model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name"
+        messages=messages,
+        temperature=kwargs.get("temperature", 0),
+        top_p=kwargs.get("top_p", 1),
+        n=kwargs.get("n", 1),
+    )
+    return chat_completion.choices[0].message.content


 async def embedding_func(texts: list[str]) -> np.ndarray:
-    headers = {
-        "Content-Type": "application/json",
-        "api-key": AZURE_OPENAI_API_KEY,
-    }
-    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_EMBEDDING_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )
+    embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)

-    payload = {"input": texts}
-
-    async with aiohttp.ClientSession() as session:
-        async with session.post(endpoint, headers=headers, json=payload) as response:
-            if response.status != 200:
-                raise ValueError(
-                    f"Request failed with status {response.status}: {await response.text()}"
-                )
-            result = await response.json()
-            embeddings = [item["embedding"] for item in result["data"]]
-            return np.array(embeddings)
+    embeddings = [item.embedding for item in embedding.data]
+    return np.array(embeddings)


 async def test_funcs():
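Reviewer note on this migration: both functions are `async def`, but `AzureOpenAI` is the SDK's synchronous client, so each request blocks the event loop. The same package ships `AsyncAzureOpenAI` with an identical constructor; a minimal sketch of the awaitable variant, reusing the demo's env-derived settings (not part of this diff):

```python
from openai import AsyncAzureOpenAI


async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    client = AsyncAzureOpenAI(
        api_key=AZURE_OPENAI_API_KEY,
        api_version=AZURE_OPENAI_API_VERSION,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # The async client returns a coroutine, so the event loop stays free.
    chat_completion = await client.chat.completions.create(
        model=AZURE_OPENAI_DEPLOYMENT,
        messages=messages,
        temperature=kwargs.get("temperature", 0),
    )
    return chat_completion.choices[0].message.content
```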
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index 6d9003ff..c8b61765 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.0.0"
+__version__ = "1.0.1"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index e6b33a9b..84efae81 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -86,9 +86,6 @@ class Neo4JStorage(BaseGraphStorage):
             )
             return single_result["edgeExists"]

-    def close(self):
-        self._driver.close()
-
     async def get_node(self, node_id: str) -> Union[dict, None]:
         async with self._driver.session() as session:
             entity_name_label = node_id.strip('"')
@@ -214,6 +211,7 @@ class Neo4JStorage(BaseGraphStorage):
                 neo4jExceptions.ServiceUnavailable,
                 neo4jExceptions.TransientError,
                 neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
             )
         ),
     )
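For context on that last hunk: the exception tuple lives inside a `tenacity` `@retry` decorator on the graph-write methods. Roughly (decorator argument values are illustrative, not copied from the file):

```python
from neo4j import exceptions as neo4jExceptions
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


@retry(
    stop=stop_after_attempt(3),  # illustrative values
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,  # the addition in this hunk
        )
    ),
)
async def upsert_node(self, node_id: str, node_data: dict):
    ...
```

One caveat worth flagging: `ClientError` also covers non-transient failures such as malformed Cypher, so retrying it can mask real bugs; scoping the retry to the specific transient subclasses may be safer.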
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 2e394b8a..2cfbd249 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,7 +114,7 @@ class OracleDB:

         logger.info("Finished check all tables in Oracle database")

-    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
@@ -173,10 +173,11 @@ class OracleKVStorage(BaseKVStorage):

     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Get doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
-        params = {"workspace":self.db.workspace, "id":id}
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+            workspace=self.db.workspace, id=id
+        )
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             data = res  # {"data":res}
             # print (data)
@@ -187,11 +188,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Get doc_chunks data by id."""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
-        params = {"workspace":self.db.workspace}
-        #print("get_by_ids:"+SQL)
-        #print(params)
-        res = await self.db.query(SQL,params, multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+        )
+        # print("get_by_ids:"+SQL)
+        res = await self.db.query(SQL, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -201,16 +202,12 @@ class OracleKVStorage(BaseKVStorage):

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicated content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
-                                                  ids=",".join([f"'{id}'" for id in keys]))
-        params = {"workspace":self.db.workspace}
-        try:
-            await self.db.query(SQL, params)
-        except Exception as e:
-            logger.error(f"Oracle database error: {e}")
-            print(SQL)
-            print(params)
-        res = await self.db.query(SQL, params,multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            workspace=self.db.workspace,
+            ids=",".join([f"'{k}'" for k in keys]),
+        )
+        res = await self.db.query(SQL, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -247,29 +244,27 @@ class OracleKVStorage(BaseKVStorage):
                 d["__vector__"] = embeddings[i]
         # print(list_data)
         for item in list_data:
-            merge_sql = SQL_TEMPLATES["merge_chunk"]
-            data = {"check_id":item["__id__"],
-                    "id":item["__id__"],
-                    "content":item["content"],
-                    "workspace":self.db.workspace,
-                    "tokens":item["tokens"],
-                    "chunk_order_index":item["chunk_order_index"],
-                    "full_doc_id":item["full_doc_id"],
-                    "content_vector":item["__vector__"]
-                    }
+            merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
+
+            values = [
+                item["__id__"],
+                item["content"],
+                self.db.workspace,
+                item["tokens"],
+                item["chunk_order_index"],
+                item["full_doc_id"],
+                item["__vector__"],
+            ]
             # print(merge_sql)
-            await self.db.execute(merge_sql, data)
+            await self.db.execute(merge_sql, values)

         if self.namespace == "full_docs":
             for k, v in self._data.items():
                 # values.clear()
-                merge_sql = SQL_TEMPLATES["merge_doc_full"]
-                data = {
-                    "check_id":k,
-                    "id":k,
-                    "content":v["content"],
-                    "workspace":self.db.workspace
-                }
+                merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
+                    check_id=k,
+                )
+                values = [k, self._data[k]["content"], self.db.workspace]
                 # print(merge_sql)
-                await self.db.execute(merge_sql, data)
+                await self.db.execute(merge_sql, values)

         return left_data
@@ -301,17 +296,18 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
+        embedding_string = ", ".join(map(str, embedding.tolist()))

-        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
-        params = {
-            "embedding_string": embedding_string,
-            "workspace": self.db.workspace,
-            "top_k": top_k,
-            "better_than_threshold": self.cosine_better_than_threshold,
-        }
+        SQL = SQL_TEMPLATES[self.namespace].format(
+            embedding_string=embedding_string,
+            dimension=dimension,
+            dtype=dtype,
+            workspace=self.db.workspace,
+            top_k=top_k,
+            better_than_threshold=self.cosine_better_than_threshold,
+        )
         # print(SQL)
-        results = await self.db.query(SQL,params=params, multirows=True)
+        results = await self.db.query(SQL, multirows=True)
         # print("vector search result:",results)
         return results
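A note on the pattern this file is converging on: `query()` no longer takes `params`, so every template is rendered with `str.format()`, splicing workspace names and ids straight into the SQL text. That forfeits bind variables (injection safety, statement caching). If binds come back, python-oracledb's async API accepts a dict of named parameters; a rough sketch under the assumption of a pool like the one `OracleDB` already holds (`query_with_binds` is a hypothetical helper, not part of this codebase):

```python
# Hypothetical helper showing named binds with python-oracledb's async API.
async def query_with_binds(pool, sql: str, params: dict, multirows: bool = False):
    async with pool.acquire() as connection:
        cursor = connection.cursor()
        # Values travel as bind variables, never spliced into the SQL text.
        await cursor.execute(sql, params)
        columns = [d[0].lower() for d in cursor.description]
        rows = await cursor.fetchall()
        if multirows:
            return [dict(zip(columns, row)) for row in rows]
        return dict(zip(columns, rows[0])) if rows else None


# Usage sketch:
# await query_with_binds(
#     db.pool,
#     "SELECT content FROM LIGHTRAG_DOC_FULL WHERE workspace = :workspace AND id = :id",
#     {"workspace": db.workspace, "id": doc_id},
# )
```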
@@ -346,18 +342,22 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_node"]
-        data = {
-            "workspace":self.db.workspace,
-            "name":entity_name,
-            "entity_type":entity_type,
-            "description":description,
-            "source_chunk_id":source_id,
-            "content":content,
-            "content_vector":content_vector
-        }
+        merge_sql = SQL_TEMPLATES["merge_node"].format(
+            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
+        )
         # print(merge_sql)
-        await self.db.execute(merge_sql,data)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                entity_name,
+                entity_type,
+                description,
+                source_id,
+                content,
+                content_vector,
+            ],
+        )
         # self._graph.add_node(node_id, **node_data)

     async def upsert_edge(
@@ -371,8 +371,6 @@ class OracleGraphStorage(BaseGraphStorage):
         keywords = edge_data["keywords"]
         description = edge_data["description"]
         source_chunk_id = edge_data["source_id"]
-        logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
-
         content = keywords + source_name + target_name + description
         contents = [content]
         batches = [
@@ -384,20 +382,27 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_edge"]
-        data = {
-            "workspace":self.db.workspace,
-            "source_name":source_name,
-            "target_name":target_name,
-            "weight":weight,
-            "keywords":keywords,
-            "description":description,
-            "source_chunk_id":source_chunk_id,
-            "content":content,
-            "content_vector":content_vector
-        }
+        merge_sql = SQL_TEMPLATES["merge_edge"].format(
+            workspace=self.db.workspace,
+            source_name=source_name,
+            target_name=target_name,
+            source_chunk_id=source_chunk_id,
+        )
         # print(merge_sql)
-        await self.db.execute(merge_sql,data)
+        await self.db.execute(
+            merge_sql,
+            [
+                self.db.workspace,
+                source_name,
+                target_name,
+                weight,
+                keywords,
+                description,
+                source_chunk_id,
+                content,
+                content_vector,
+            ],
+        )
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
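Note the split introduced here: the MERGE match keys are interpolated with `.format()`, while the insert payload travels as a list that binds positionally. In python-oracledb, a sequence maps to numbered placeholders in order; a tiny illustration (table and columns hypothetical):

```python
# Hypothetical demo of positional binds: the list maps to :1, :2, :3 in order.
async def insert_demo_row(connection):
    cursor = connection.cursor()
    await cursor.execute(
        "INSERT INTO demo_nodes (workspace, name, weight) VALUES (:1, :2, :3)",
        ["default", "entity-a", 1.0],
    )
```

The `merge_edge` template's VALUES clause is switched to `(:1,...,:9)` at the bottom of this file's diff, but no matching change for `merge_node` appears here, so its VALUES clause may still use named placeholders, which would not match a list-style bind.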
@@ -427,14 +432,12 @@ class OracleGraphStorage(BaseGraphStorage):
     #################### query method #################

     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by node id."""
-        SQL = SQL_TEMPLATES["has_node"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["has_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
         # print(self.db.workspace, node_id)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Node exist!",res)
             return True
@@ -444,14 +447,13 @@ class OracleGraphStorage(BaseGraphStorage):

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"]
-        params = {
-            "workspace":self.db.workspace,
-            "source_node_id":source_node_id,
-            "target_node_id":target_node_id
-        }
+        SQL = SQL_TEMPLATES["has_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Edge exist!",res)
             return True
@@ -461,13 +463,11 @@ class OracleGraphStorage(BaseGraphStorage):

     async def node_degree(self, node_id: str) -> int:
         """Get a node's degree by node id."""
-        SQL = SQL_TEMPLATES["node_degree"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["node_degree"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Node degree",res["degree"])
             return res["degree"]
@@ -483,14 +483,12 @@ class OracleGraphStorage(BaseGraphStorage):

     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by node id."""
-        SQL = SQL_TEMPLATES["get_node"]
-        params = {
-            "workspace":self.db.workspace,
-            "node_id":node_id
-        }
+        SQL = SQL_TEMPLATES["get_node"].format(
+            workspace=self.db.workspace, node_id=node_id
+        )
         # print(self.db.workspace, node_id)
         # print(SQL)
-        res = await self.db.query(SQL,params)
+        res = await self.db.query(SQL)
         if res:
             # print("Get node!",self.db.workspace, node_id,res)
             return res
@@ -502,13 +500,12 @@ class OracleGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"]
-        params = {
-            "workspace":self.db.workspace,
-            "source_node_id":source_node_id,
-            "target_node_id":target_node_id
-        }
-        res = await self.db.query(SQL,params)
+        SQL = SQL_TEMPLATES["get_edge"].format(
+            workspace=self.db.workspace,
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+        )
+        res = await self.db.query(SQL)
         if res:
             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
@@ -519,12 +516,10 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"]
-            params = {
-                "workspace":self.db.workspace,
-                "source_node_id":source_node_id
-            }
-            res = await self.db.query(sql=SQL, params=params, multirows=True)
+            SQL = SQL_TEMPLATES["get_node_edges"].format(
+                workspace=self.db.workspace, source_node_id=source_node_id
+            )
+            res = await self.db.query(sql=SQL, multirows=True)
             if res:
                 data = [(i["source_name"], i["target_name"]) for i in res]
                 # print("Get node edge!",self.db.workspace, source_node_id,data)
                 return data
@@ -532,29 +527,7 @@ class OracleGraphStorage(BaseGraphStorage):
         else:
             # print("Node Edge not exist!",self.db.workspace, source_node_id)
             return []
-
-    async def get_all_nodes(self, limit: int):
-        """Query all nodes."""
-        SQL = SQL_TEMPLATES["get_all_nodes"]
-        params = {"workspace":self.db.workspace, "limit":str(limit)}
-        res = await self.db.query(sql=SQL,params=params, multirows=True)
-        if res:
-            return res
-
-    async def get_all_edges(self, limit: int):
-        """Query all edges."""
-        SQL = SQL_TEMPLATES["get_all_edges"]
-        params = {"workspace":self.db.workspace, "limit":str(limit)}
-        res = await self.db.query(sql=SQL,params=params, multirows=True)
-        if res:
-            return res
-
-    async def get_statistics(self):
-        SQL = SQL_TEMPLATES["get_statistics"]
-        params = {"workspace":self.db.workspace}
-        res = await self.db.query(sql=SQL,params=params, multirows=True)
-        if res:
-            return res

 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
@@ -726,37 +699,5 @@ SQL_TEMPLATES = {
     ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
     WHEN NOT MATCHED THEN
     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
-    "get_all_nodes":"""WITH t0 AS (
-                    SELECT name AS id, entity_type AS label, entity_type, description,
-                        '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
-                    FROM lightrag_graph_nodes
-                    WHERE workspace = :workspace
-                    ORDER BY createtime DESC fetch first :limit rows only
-                ), t1 AS (
-                    SELECT t0.id, source_chunk_id
-                    FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
-                ), t2 AS (
-                    SELECT t1.id, LISTAGG(t2.content, '\n') content
-                    FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
-                    GROUP BY t1.id
-                )
-                SELECT t0.id, label, entity_type, description, t2.content
-                FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
-    "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
-                t1.weight,t1.DESCRIPTION,t2.content
-                FROM LIGHTRAG_GRAPH_EDGES t1
-                LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
-                WHERE t1.workspace=:workspace
-                order by t1.CREATETIME DESC
-                fetch first :limit rows only""",
-    "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
-                count(distinct CASE WHEN type='edge' THEN id END) as edges_count
-                FROM (
-                    select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
-                        MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
-                    UNION
-                    select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
-                        MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
-                )""",
+    values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
 }

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 2687877a..2d45dbce 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -172,9 +172,7 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace="chunk_entity_relation",
-            global_config=asdict(self),
-            embedding_func=self.embedding_func,
+            namespace="chunk_entity_relation", global_config=asdict(self)
         )
         ####
         # add embedding func by walter over
@@ -226,6 +224,7 @@ class LightRAG:
         return loop.run_until_complete(self.ainsert(string_or_strings))

     async def ainsert(self, string_or_strings):
+        update_storage = False
         try:
             if isinstance(string_or_strings, str):
                 string_or_strings = [string_or_strings]
@@ -239,6 +238,7 @@ class LightRAG:
             if not len(new_docs):
                 logger.warning("All docs are already in the storage")
                 return
+            update_storage = True
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")

             inserting_chunks = {}
@@ -285,7 +285,8 @@ class LightRAG:
             await self.full_docs.upsert(new_docs)
             await self.text_chunks.upsert(inserting_chunks)
         finally:
-            await self._insert_done()
+            if update_storage:
+                await self._insert_done()

     async def _insert_done(self):
         tasks = []

diff --git a/lightrag/llm.py b/lightrag/llm.py
index 6263f153..1acf07e0 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -696,13 +696,17 @@ async def bedrock_embedding(


 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()


 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
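The `hf_embedding` fix matters for GPU and bfloat16 models: inputs now follow the model's device, and bfloat16 outputs are upcast before `.numpy()`, which is required because NumPy has no bfloat16 dtype. A usage sketch, assuming a standard `transformers` checkpoint (the model name is illustrative):

```python
import asyncio

import torch
from transformers import AutoModel, AutoTokenizer

from lightrag.llm import hf_embedding

name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,  # would crash .numpy() without the upcast above
)
if torch.cuda.is_available():
    model = model.to("cuda")  # inputs now follow the model's device automatically

vecs = asyncio.run(hf_embedding(["LightRAG embeds text"], tokenizer, model))
print(vecs.shape)  # (1, hidden_size)
```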
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 12f78dcd..1071f8c2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
     all_text_units_lookup = {}
     for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
         for c_id in this_text_units:
-            if c_id in all_text_units_lookup:
-                continue
-            relation_counts = 0
-            if this_edges:  # Add check for None edges
+            if c_id not in all_text_units_lookup:
+                all_text_units_lookup[c_id] = {
+                    "data": await text_chunks_db.get_by_id(c_id),
+                    "order": index,
+                    "relation_counts": 0,
+                }
+
+            if this_edges:
                 for e in this_edges:
                     if (
                         e[1] in all_one_hop_text_units_lookup
                         and c_id in all_one_hop_text_units_lookup[e[1]]
                     ):
-                        relation_counts += 1
-
-            chunk_data = await text_chunks_db.get_by_id(c_id)
-            if chunk_data is not None and "content" in chunk_data:  # Add content check
-                all_text_units_lookup[c_id] = {
-                    "data": chunk_data,
-                    "order": index,
-                    "relation_counts": relation_counts,
-                }
+                        all_text_units_lookup[c_id]["relation_counts"] += 1

     # Filter out None values and ensure data has content
     all_text_units = [
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
     all_related_edges = await asyncio.gather(
         *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
     )
-    all_edges = set()
+    all_edges = []
+    seen = set()
+
     for this_edges in all_related_edges:
-        all_edges.update([tuple(sorted(e)) for e in this_edges])
-    all_edges = list(all_edges)
+        for e in this_edges:
+            sorted_edge = tuple(sorted(e))
+            if sorted_edge not in seen:
+                seen.add(sorted_edge)
+                all_edges.append(sorted_edge)
+
     all_edges_pack = await asyncio.gather(
         *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
     )
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
 ):
-    entity_names = set()
+    entity_names = []
+    seen = set()
+
     for e in edge_datas:
-        entity_names.add(e["src_id"])
-        entity_names.add(e["tgt_id"])
+        if e["src_id"] not in seen:
+            entity_names.append(e["src_id"])
+            seen.add(e["src_id"])
+        if e["tgt_id"] not in seen:
+            entity_names.append(e["tgt_id"])
+            seen.add(e["tgt_id"])

     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
     )

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9c0e7577..fc739002 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -290,13 +290,19 @@ def process_combine_contexts(hl, ll):
     if list_ll:
         list_ll = [",".join(item[1:]) for item in list_ll if item]

-    combined_sources_set = set(filter(None, list_hl + list_ll))
+    combined_sources = []
+    seen = set()

-    combined_sources = [",\t".join(header)]
+    for item in list_hl + list_ll:
+        if item and item not in seen:
+            combined_sources.append(item)
+            seen.add(item)

-    for i, item in enumerate(combined_sources_set, start=1):
-        combined_sources.append(f"{i},\t{item}")
+    combined_sources_result = [",\t".join(header)]

-    combined_sources = "\n".join(combined_sources)
+    for i, item in enumerate(combined_sources, start=1):
+        combined_sources_result.append(f"{i},\t{item}")

-    return combined_sources
+    combined_sources_result = "\n".join(combined_sources_result)
+
+    return combined_sources_result

diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py
index 43c44056..e318c145 100644
--- a/reproduce/Step_1.py
+++ b/reproduce/Step_1.py
@@ -24,7 +24,7 @@ def insert_text(rag, file_path):

 cls = "agriculture"
-WORKING_DIR = "../{cls}"
+WORKING_DIR = f"../{cls}"

 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
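The one-character `Step_1.py` fix is easy to gloss over: without the `f` prefix, `{cls}` is never interpolated, so the script would create a directory literally named `{cls}` next to the repo. A quick check:

```python
cls = "agriculture"
print("../{cls}")   # ../{cls}       (braces are literal in a plain string)
print(f"../{cls}")  # ../agriculture (f-string interpolates the variable)
```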