cleanup code
This commit is contained in:
@@ -1,5 +1,3 @@
|
|||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
lightrag:
|
lightrag:
|
||||||
build: .
|
build: .
|
||||||
|
@@ -98,7 +98,6 @@ async def init():
|
|||||||
|
|
||||||
# Initialize LightRAG
|
# Initialize LightRAG
|
||||||
# We use Oracle DB as the KV/vector/graph storage
|
# We use Oracle DB as the KV/vector/graph storage
|
||||||
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
|
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
enable_llm_cache=False,
|
enable_llm_cache=False,
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
|
@@ -1,9 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import inspect
|
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.llm import openai_complete, openai_embed
|
from lightrag.llm import openai_complete, openai_embed
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc
|
||||||
from lightrag.lightrag import always_get_an_event_loop
|
|
||||||
from lightrag import QueryParam
|
from lightrag import QueryParam
|
||||||
|
|
||||||
# WorkingDir
|
# WorkingDir
|
||||||
@@ -48,8 +46,3 @@ async def print_stream(stream):
|
|||||||
print(chunk, end="", flush=True)
|
print(chunk, end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
if inspect.isasyncgen(resp):
|
|
||||||
loop.run_until_complete(print_stream(resp))
|
|
||||||
else:
|
|
||||||
print(resp)
|
|
||||||
|
@@ -63,7 +63,6 @@ async def main():
|
|||||||
|
|
||||||
# Initialize LightRAG
|
# Initialize LightRAG
|
||||||
# We use TiDB DB as the KV/vector
|
# We use TiDB DB as the KV/vector
|
||||||
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
|
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
enable_llm_cache=False,
|
enable_llm_cache=False,
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
|
@@ -32,8 +32,10 @@ from .operate import (
|
|||||||
from .prompt import GRAPH_FIELD_SEP
|
from .prompt import GRAPH_FIELD_SEP
|
||||||
from .utils import (
|
from .utils import (
|
||||||
EmbeddingFunc,
|
EmbeddingFunc,
|
||||||
|
always_get_an_event_loop,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
convert_response_to_json,
|
convert_response_to_json,
|
||||||
|
lazy_external_import,
|
||||||
limit_async_func_call,
|
limit_async_func_call,
|
||||||
logger,
|
logger,
|
||||||
set_logger,
|
set_logger,
|
||||||
@@ -182,48 +184,9 @@ STORAGES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
|
|
||||||
"""Lazily import a class from an external module based on the package of the caller."""
|
|
||||||
# Get the caller's module and package
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
caller_frame = inspect.currentframe().f_back
|
|
||||||
module = inspect.getmodule(caller_frame)
|
|
||||||
package = module.__package__ if module else None
|
|
||||||
|
|
||||||
def import_class(*args: Any, **kwargs: Any):
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
module = importlib.import_module(module_name, package=package)
|
|
||||||
cls = getattr(module, class_name)
|
|
||||||
return cls(*args, **kwargs)
|
|
||||||
|
|
||||||
return import_class
|
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
||||||
"""
|
|
||||||
Ensure that there is always an event loop available.
|
|
||||||
|
|
||||||
This function tries to get the current event loop. If the current event loop is closed or does not exist,
|
|
||||||
it creates a new event loop and sets it as the current event loop.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
asyncio.AbstractEventLoop: The current or newly created event loop.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try to get the current event loop
|
|
||||||
current_loop = asyncio.get_event_loop()
|
|
||||||
if current_loop.is_closed():
|
|
||||||
raise RuntimeError("Event loop is closed.")
|
|
||||||
return current_loop
|
|
||||||
|
|
||||||
except RuntimeError:
|
|
||||||
# If no event loop exists or it is closed, create a new one
|
|
||||||
logger.info("Creating a new event loop in main thread.")
|
|
||||||
new_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(new_loop)
|
|
||||||
return new_loop
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@@ -428,46 +391,6 @@ class LightRAG:
|
|||||||
The default function is :func:`.utils.convert_response_to_json`.
|
The default function is :func:`.utils.convert_response_to_json`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def verify_storage_implementation(
|
|
||||||
self, storage_type: str, storage_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Verify if storage implementation is compatible with specified storage type
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
|
||||||
storage_name: Storage implementation name
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If storage implementation is incompatible or missing required methods
|
|
||||||
"""
|
|
||||||
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
|
||||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
|
||||||
|
|
||||||
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
|
||||||
if storage_name not in storage_info["implementations"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
|
||||||
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_storage_env_vars(self, storage_name: str) -> None:
|
|
||||||
"""Check if all required environment variables for storage implementation exist
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_name: Storage implementation name
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If required environment variables are missing
|
|
||||||
"""
|
|
||||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
|
||||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
|
||||||
|
|
||||||
if missing_vars:
|
|
||||||
raise ValueError(
|
|
||||||
f"Storage implementation '{storage_name}' requires the following "
|
|
||||||
f"environment variables: {', '.join(missing_vars)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
os.makedirs(self.log_dir, exist_ok=True)
|
os.makedirs(self.log_dir, exist_ok=True)
|
||||||
log_file = os.path.join(self.log_dir, "lightrag.log")
|
log_file = os.path.join(self.log_dir, "lightrag.log")
|
||||||
@@ -1681,3 +1604,43 @@ class LightRAG:
|
|||||||
result["vector_data"] = vector_data[0] if vector_data else None
|
result["vector_data"] = vector_data[0] if vector_data else None
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def verify_storage_implementation(
|
||||||
|
self, storage_type: str, storage_name: str
|
||||||
|
) -> None:
|
||||||
|
"""Verify if storage implementation is compatible with specified storage type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
||||||
|
storage_name: Storage implementation name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If storage implementation is incompatible or missing required methods
|
||||||
|
"""
|
||||||
|
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
||||||
|
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||||
|
|
||||||
|
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
||||||
|
if storage_name not in storage_info["implementations"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
||||||
|
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_storage_env_vars(self, storage_name: str) -> None:
|
||||||
|
"""Check if all required environment variables for storage implementation exist
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_name: Storage implementation name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required environment variables are missing
|
||||||
|
"""
|
||||||
|
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||||
|
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||||
|
|
||||||
|
if missing_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Storage implementation '{storage_name}' requires the following "
|
||||||
|
f"environment variables: {', '.join(missing_vars)}"
|
||||||
|
)
|
@@ -713,3 +713,47 @@ def get_conversation_turns(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(formatted_turns)
|
return "\n".join(formatted_turns)
|
||||||
|
|
||||||
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
|
"""
|
||||||
|
Ensure that there is always an event loop available.
|
||||||
|
|
||||||
|
This function tries to get the current event loop. If the current event loop is closed or does not exist,
|
||||||
|
it creates a new event loop and sets it as the current event loop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
asyncio.AbstractEventLoop: The current or newly created event loop.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to get the current event loop
|
||||||
|
current_loop = asyncio.get_event_loop()
|
||||||
|
if current_loop.is_closed():
|
||||||
|
raise RuntimeError("Event loop is closed.")
|
||||||
|
return current_loop
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
# If no event loop exists or it is closed, create a new one
|
||||||
|
logger.info("Creating a new event loop in main thread.")
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
return new_loop
|
||||||
|
|
||||||
|
|
||||||
|
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
|
||||||
|
"""Lazily import a class from an external module based on the package of the caller."""
|
||||||
|
# Get the caller's module and package
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
caller_frame = inspect.currentframe().f_back
|
||||||
|
module = inspect.getmodule(caller_frame)
|
||||||
|
package = module.__package__ if module else None
|
||||||
|
|
||||||
|
def import_class(*args: Any, **kwargs: Any):
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
module = importlib.import_module(module_name, package=package)
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
|
return import_class
|
||||||
|
|
@@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import asyncio
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.utils import always_get_an_event_loop
|
||||||
|
|
||||||
|
|
||||||
def extract_queries(file_path):
|
def extract_queries(file_path):
|
||||||
@@ -23,14 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
|
|||||||
return None, {"query": query_text, "error": str(e)}
|
return None, {"query": query_text, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
return loop
|
|
||||||
|
|
||||||
|
|
||||||
def run_queries_and_save_to_json(
|
def run_queries_and_save_to_json(
|
||||||
queries, rag_instance, query_param, output_file, error_file
|
queries, rag_instance, query_param, output_file, error_file
|
||||||
|
@@ -1,10 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import asyncio
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@@ -55,13 +54,7 @@ async def process_query(query_text, rag_instance, query_param):
|
|||||||
return None, {"query": query_text, "error": str(e)}
|
return None, {"query": query_text, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
return loop
|
|
||||||
|
|
||||||
|
|
||||||
def run_queries_and_save_to_json(
|
def run_queries_and_save_to_json(
|
||||||
|
Reference in New Issue
Block a user