Inference running locally. Use Neo4j next.

Ken Wiltshire
2024-10-27 15:37:41 -04:00
parent 474fe59a79
commit 0796d3d8e0
8 changed files with 98 additions and 76 deletions

.gitignore (vendored)

@@ -4,4 +4,5 @@ dickens/
 book.txt
 lightrag-dev/
 .idea/
 dist/
+env/

lightrag/__init__.py

@@ -1,5 +1,27 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
+print ("init package vars here. ......")
+from .neo4j import GraphStorage as Neo4JStorage
+
+# import sys
+# import importlib
+# # Specify the path to the directory containing the module
+# # Add the directory to the system path
+# module_dir = '/Users/kenwiltshire/documents/dev/LightRag/lightrag/kg'
+# sys.path.append(module_dir)
+# # Specify the module name
+# module_name = 'neo4j'
+# # Import the module
+# spec = importlib.util.spec_from_file_location(module_name, f'{module_dir}/{module_name}.py')
+# Neo4JStorage = importlib.util.module_from_spec(spec)
+# spec.loader.exec_module(Neo4JStorage)
+
+# Relative imports are still possible by adding a leading period to the module name when using the from ... import form:
+# # Import names from pkg.string
+# from .string import name1, name2
+# # Import pkg.string
+# from . import string

 __version__ = "0.0.7"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
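A note on the commented-out importlib experiment above: importlib.util.module_from_spec returns a module object, so binding its result to Neo4JStorage would name the module, not the GraphStorage class. A minimal working sketch of that dynamic-import pattern, with a hypothetical path:

import importlib.util

module_dir = '/path/to/LightRAG/lightrag/kg'  # hypothetical location
module_name = 'neo4j'

spec = importlib.util.spec_from_file_location(module_name, f'{module_dir}/{module_name}.py')
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)       # run the module body in its own namespace
Neo4JStorage = module.GraphStorage    # bind the class, not the module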

lightrag/kg/neo4j.py

@@ -3,11 +3,15 @@ import html
 import os
 from dataclasses import dataclass
 from typing import Any, Union, cast
-import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB

-from .utils import load_json, logger, write_json
+# import package.common.utils as utils
+from lightrag.utils import load_json, logger, write_json

 from ..base import (
     BaseGraphStorage
 )
@@ -22,10 +26,10 @@ PASSWORD = "your_password"

 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
-    def load_nx_graph(file_name) -> nx.Graph:
-        if os.path.exists(file_name):
-            return nx.read_graphml(file_name)
-        return None
+    # def load_nx_graph(file_name) -> nx.Graph:
+    #     if os.path.exists(file_name):
+    #         return nx.read_graphml(file_name)
+    #     return None

     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
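For reference, the URI, USERNAME, and PASSWORD constants in the hunk context above feed the standard neo4j Python driver. A minimal connection sketch; the URI and credentials are placeholders:

from neo4j import GraphDatabase

URI = "neo4j://localhost:7687"  # placeholder
USERNAME = "neo4j"              # placeholder
PASSWORD = "your_password"      # placeholder

driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
with driver.session() as session:
    print(session.run("RETURN 1 AS ok").single()["ok"])  # prints 1 on success
driver.close()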
@@ -102,7 +106,7 @@ class GraphStorage(BaseGraphStorage):
             result = session.run(
                 """MATCH (n1:{node_label1})-[r]-(n2:{node_label2})
                 RETURN count(r) AS degree"""
-                .format(node_label1=node_label1, node_label2=node_label2)
+                .format(entity_name__label_source=entity_name__label_source, entity_name_label_target=entity_name_label_target)
             )
             record = result.single()
             return record["degree"]
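As committed, the new .format() kwargs do not match the {node_label1}/{node_label2} placeholders in the query template, so the call raises KeyError before the query ever runs. A sketch of a matching version, assuming entity_name_label_source and entity_name_label_target hold the sanitized label strings used elsewhere in this file:

# Placeholder names must match the .format() kwargs; node labels cannot be
# bound as Cypher parameters, hence the string interpolation.
query = (
    """MATCH (n1:{node_label1})-[r]-(n2:{node_label2})
    RETURN count(r) AS degree"""
    .format(node_label1=entity_name_label_source, node_label2=entity_name_label_target)
)
degree = session.run(query).single()["degree"]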
@@ -263,7 +267,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             # Define the Cypher query
             options = self.global_config["node2vec_params"]
-            query = f"""CALL gds.node2vec.stream('myGraph', {**options})
+            query = f"""CALL gds.node2vec.stream('myGraph', {options}) # **options
                 YIELD nodeId, embedding
                 RETURN nodeId, embedding"""
             # Run the query and process the results
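The old f-string's {**options} is a SyntaxError inside an f-string, and the new {options} splices a Python dict repr into Cypher, whose quoted keys are not valid Cypher map syntax. A sketch that passes the configuration as a query parameter instead, assuming a projected graph named 'myGraph' already exists:

# Pass the node2vec configuration as a Cypher parameter ($config) rather than
# interpolating a Python dict into the query text.
options = self.global_config["node2vec_params"]  # e.g. {"embeddingDimension": 128}
query = """CALL gds.node2vec.stream('myGraph', $config)
           YIELD nodeId, embedding
           RETURN nodeId, embedding"""
results = session.run(query, config=options)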

lightrag/lightrag.py

@@ -1,5 +1,6 @@
 import asyncio
 import os
+import importlib
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -23,6 +24,11 @@ from .storage import (
     NanoVectorDBStorage,
     NetworkXStorage,
 )
+from .kg.neo4j import (
+    GraphStorage as Neo4JStorage
+)
 from .utils import (
     EmbeddingFunc,
     compute_mdhash_id,
@@ -93,7 +99,14 @@ class LightRAG:
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
-    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
+    # module = importlib.import_module('kg.neo4j')
+    # Neo4JStorage = getattr(module, 'GraphStorage')
+    if True==False:
+        graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
+    else:
+        graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
     enable_llm_cache: bool = True
     # extension
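The if True==False switch above hard-codes NetworkXStorage and leaves the Neo4j branch dead. A sketch of a config-driven alternative; the LIGHTRAG_KG environment variable is hypothetical, not part of the codebase:

import os

# Select the graph backend from the environment instead of a dead if-branch.
if os.environ.get("LIGHTRAG_KG", "networkx") == "neo4j":
    graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
else:
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage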

lightrag/llm.py

@@ -72,7 +72,9 @@ async def openai_complete_if_cache(

 @retry(
     stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    #kw_
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    # wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def azure_openai_complete_if_cache(model,
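tenacity's wait_exponential computes each pause as roughly multiplier * 2**attempt, clamped to [min, max], so raising max from 10 to 60 seconds permits longer backoff under sustained rate limiting. A toy sketch of the same policy; RuntimeError stands in for the OpenAI exceptions above:

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(RuntimeError),
)
def flaky():
    raise RuntimeError("simulated rate limit")

# flaky() waits ~4s after each of the first two failures, then re-raises.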

lightrag/storage.py

@@ -6,8 +6,6 @@ from typing import Any, Union, cast
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
-from kg.neo4j import GraphStorage
-
 from .utils import load_json, logger, write_json
 from .base import (
@@ -99,66 +97,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
             d["__vector__"] = embeddings[i]
         results = self._client.upsert(datas=list_data)
         return results

-@dataclass
-class PineConeVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = 0.2
-
-    def __post_init__(self):
-        self._client_file_name = os.path.join(
-            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
-        )
-        self._max_batch_size = self.global_config["embedding_batch_num"]
-        self._client = NanoVectorDB(
-            self.embedding_func.embedding_dim, storage_file=self._client_file_name
-        )
-        import os
-        from pinecone import Pinecone
-        pc = Pinecone() #api_key=os.environ.get('PINECONE_API_KEY'))
-        # From here on, everything is identical to the REST-based SDK.
-        self._client = pc.Index(host=self._client_pinecone_host) #'my-index-8833ca1.svc.us-east1-gcp.pinecone.io')
-        self.cosine_better_than_threshold = self.global_config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
-
-    async def upsert(self, data: dict[str, dict]):
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
-        list_data = [
-            {
-                "__id__": k,
-                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
-            }
-            for k, v in data.items()
-        ]
-        contents = [v["content"] for v in data.values()]
-        batches = [
-            contents[i : i + self._max_batch_size]
-            for i in range(0, len(contents), self._max_batch_size)
-        ]
-        embeddings_list = await asyncio.gather(
-            *[self.embedding_func(batch) for batch in batches]
-        )
-        embeddings = np.concatenate(embeddings_list)
-        for i, d in enumerate(list_data):
-            d["__vector__"] = embeddings[i]
-        # self._client.upsert(vectors=[]) pinecone
-        results = self._client.upsert(datas=list_data)
-        return results
-
     async def query(self, query: str, top_k=5):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
-        # self._client.query(vector=[...], top_key=10) pinecone
         results = self._client.query(
-            vector=embedding,
+            query=embedding,
             top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold, ???
+            better_than_threshold=self.cosine_better_than_threshold,
         )
         results = [
             {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
@@ -166,8 +112,7 @@ class PineConeVectorDBStorage(BaseVectorStorage):
         return results

     async def index_done_callback(self):
-        print("self._client.save()")
-        # self._client.save()
+        self._client.save()

 @dataclass
@@ -298,5 +243,3 @@ class NetworkXStorage(BaseGraphStorage):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
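The query() fix above restores nano_vectordb's keyword arguments (query=, not Pinecone's vector=). A standalone usage sketch against a toy index; the dimension and threshold are illustrative:

import numpy as np
from nano_vectordb import NanoVectorDB

db = NanoVectorDB(8, storage_file="vdb_demo.json")  # 8-dim toy index
db.upsert(datas=[{"__id__": "doc-a", "__vector__": np.random.rand(8)}])
hits = db.query(query=np.random.rand(8), top_k=5, better_than_threshold=0.2)
print([(h["__id__"], h["__metrics__"]) for h in hits])  # id/score pairs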

requirements.txt

@@ -12,4 +12,5 @@ torch
 transformers
 xxhash
 pyvis
 aiohttp
+neo4j

test.py (new file, +36 lines)

@@ -0,0 +1,36 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+
+#########
+# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
+# import nest_asyncio
+# nest_asyncio.apply()
+#########
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
+    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
+)
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
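The four queries above differ only in their mode; an equivalent compact form, assuming the same rag instance:

for mode in ["naive", "local", "global", "hybrid"]:
    print(rag.query("What are the top themes in this story?", param=QueryParam(mode=mode)))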