Fix linting errors

This commit is contained in:
Gurjot Singh
2025-01-31 19:05:47 +05:30
parent 8a624e198a
commit 2894e8faf2
2 changed files with 15 additions and 15 deletions

View File

@@ -8,7 +8,6 @@ from sentence_transformers import SentenceTransformer
from openai import AzureOpenAI
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.faiss_impl import FaissVectorDBStorage
# Configure Logging
logging.basicConfig(level=logging.INFO)
@@ -20,14 +19,10 @@ AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
async def llm_model_func(
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
**kwargs
) -> str:
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
# Create a client for AzureOpenAI
client = AzureOpenAI(
api_key=AZURE_OPENAI_API_KEY,
@@ -56,12 +51,12 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray:
model = SentenceTransformer('all-MiniLM-L6-v2') model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings
def main():
def main():
WORKING_DIR = "./dickens"
# Initialize LightRAG with the LLM model function and embedding function
@@ -76,7 +71,7 @@ def main():
vector_storage="FaissVectorDBStorage",
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": 0.3 # Your desired threshold
} },
)
# Insert the custom chunks into LightRAG

View File

@@ -22,6 +22,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
A Faiss-based Vector DB Storage for LightRAG.
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
"""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
@@ -93,7 +94,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
for i in range(0, len(contents), self._max_batch_size)
]
pbar = tqdm_async(total=len(batches), desc="Generating embeddings", unit="batch") pbar = tqdm_async(
total=len(batches), desc="Generating embeddings", unit="batch"
)
async def wrapped_task(batch):
result = await self.embedding_func(batch)
@@ -200,7 +203,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
if to_remove:
self._remove_faiss_ids(to_remove)
logger.info(f"Successfully deleted {len(to_remove)} vectors from {self.namespace}") logger.info(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
async def delete_entity(self, entity_name: str):
"""