Manually reformatted files

Author: Sanketh Kumar
Date: 2024-10-25 13:32:25 +05:30
Commit: 5e3ab98d83
Parent: 8fbbf70a83
11 changed files with 175 additions and 95 deletions


@@ -15,7 +15,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
@@ -27,4 +27,4 @@ jobs:
pip install pre-commit
- name: Run pre-commit
run: pre-commit run --all-files
run: pre-commit run --all-files

.gitignore vendored (2 lines changed)

@@ -4,4 +4,4 @@ dickens/
book.txt
lightrag-dev/
.idea/
dist/
dist/


@@ -58,8 +58,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -157,7 +157,7 @@ rag = LightRAG(
<details>
<summary> Using Ollama Models </summary>
* If you want to use Ollama models, you only need to set LightRAG as follows:
```python
@@ -328,8 +328,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -382,7 +382,7 @@ def main():
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()


@@ -3,7 +3,7 @@ from pyvis.network import Network
import random
# Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
net = Network(notebook=True)
@@ -13,7 +13,7 @@ net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
net.show('knowledge_graph.html')
net.show("knowledge_graph.html")


@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path)
if json_data:
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data")
return None
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main():
# Paths
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return
# Load nodes and edges
nodes = json_data.get('nodes', [])
edges = json_data.get('edges', [])
nodes = json_data.get("nodes", [])
edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
session.execute_write(
process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
)
# Insert edges in batches
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
session.execute_write(
process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
)
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
if __name__ == "__main__":
main()


@@ -52,6 +52,7 @@ async def test_funcs():
# asyncio.run(test_funcs())
async def main():
try:
embedding_dimension = await get_embedding_dim()
@@ -61,35 +62,47 @@ async def main():
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
rag.query(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())


@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts,
model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("SILICONFLOW_API_KEY"),
max_token_size=512
max_token_size=512,
)


@@ -27,11 +27,12 @@ rag = LightRAG(
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
if filename.endswith('.txt'):
if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read())
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts)
return
except Exception as e:
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
print(
f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
)
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
except Exception as e:
print(f"Error performing naive search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
except Exception as e:
print(f"Error performing local search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
except Exception as e:
print(f"Error performing global search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
except Exception as e:
print(f"Error performing hybrid search: {e}")
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()


@@ -7,7 +7,13 @@ import aiohttp
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
import base64
import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(model,
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs):
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
)
return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -328,8 +344,9 @@ async def gpt_4o_mini_complete(
**kwargs,
)
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
@@ -339,6 +356,7 @@ async def azure_openai_complete(
**kwargs,
)
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -418,9 +436,11 @@ async def azure_openai_embedding(
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
@@ -440,35 +460,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith('Bearer '):
api_key = 'Bearer ' + api_key
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {
"Authorization": api_key,
"Content-Type": "application/json"
}
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {
"model": model,
"input": truncate_texts,
"encoding_format": "base64"
}
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if 'code' in content:
if "code" in content:
raise ValueError(content)
base64_strings = [item['embedding'] for item in content['data']]
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack('<' + 'f' * n, decode_bytes)
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
@@ -563,6 +576,7 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +594,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
gen_func: Callable[[Any], str] = Field(
...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
class Config:
arbitrary_types_allowed = True
class MultiModel():
class MultiModel:
"""
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for spliting across diffrent models or providers.
@@ -611,26 +631,31 @@ class MultiModel():
)
```
"""
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
self,
prompt, system_prompt=None, history_messages=[], **kwargs
self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
return await next_model.gen_func(
**args
args = dict(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
**next_model.kwargs,
)
return await next_model.gen_func(**args)
if __name__ == "__main__":
import asyncio


@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {
"nodes": [],
"edges": []
}
data = {"nodes": [], "edges": []}
# Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall('.//node', namespace):
for node in root.findall(".//node", namespace):
node_data = {
"id": node.get('id').strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
"id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace):
for edge in root.findall(".//edge", namespace):
edge_data = {
"source": edge.get('source').strip('"'),
"target": edge.get('target').strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)


@@ -1,15 +1,15 @@
accelerate
aioboto3
aiohttp
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
pyvis
tenacity
tiktoken
torch
transformers
xxhash
pyvis
aiohttp