inference running locally. use neo4j next

This commit is contained in:
Ken Wiltshire
2024-10-27 15:37:41 -04:00
parent cc45ea7310
commit 01b7df7afa
8 changed files with 98 additions and 76 deletions

View File

@@ -1,5 +1,27 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
print ("init package vars here. ......")
from .neo4j import GraphStorage as Neo4JStorage
# import sys
# import importlib
# # Specify the path to the directory containing the module
# # Add the directory to the system path
# module_dir = '/Users/kenwiltshire/documents/dev/LightRag/lightrag/kg'
# sys.path.append(module_dir)
# # Specify the module name
# module_name = 'neo4j'
# # Import the module
# spec = importlib.util.spec_from_file_location(module_name, f'{module_dir}/{module_name}.py')
# Neo4JStorage = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(Neo4JStorage)
# Relative imports are still possible by adding a leading period to the module name when using the from ... import form:
# # Import names from pkg.string
# from .string import name1, name2
# # Import pkg.string
# from . import string
__version__ = "0.0.7"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -3,11 +3,15 @@ import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
from .utils import load_json, logger, write_json
# import package.common.utils as utils
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
)
@@ -22,10 +26,10 @@ PASSWORD = "your_password"
@dataclass
class GraphStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
# def load_nx_graph(file_name) -> nx.Graph:
# if os.path.exists(file_name):
# return nx.read_graphml(file_name)
# return None
def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph()
@@ -102,7 +106,7 @@ class GraphStorage(BaseGraphStorage):
result = session.run(
"""MATCH (n1:{node_label1})-[r]-(n2:{node_label2})
RETURN count(r) AS degree"""
.format(node_label1=node_label1, node_label2=node_label2)
.format(entity_name__label_source=entity_name__label_source, entity_name_label_target=entity_name_label_target)
)
record = result.single()
return record["degree"]
@@ -263,7 +267,7 @@ class GraphStorage(BaseGraphStorage):
with self._driver.session() as session:
#Define the Cypher query
options = self.global_config["node2vec_params"]
query = f"""CALL gds.node2vec.stream('myGraph', {**options})
query = f"""CALL gds.node2vec.stream('myGraph', {options}) # **options
YIELD nodeId, embedding
RETURN nodeId, embedding"""
# Run the query and process the results