Merge pull request #116 from Dormiveglia-elf/hotfix/embedding-dim
[hotfix-#75][embedding] Fix the potential embedding problem
This commit is contained in:
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding_dim():
|
||||||
|
test_text = ["This is a test sentence."]
|
||||||
|
embedding = await embedding_func(test_text)
|
||||||
|
embedding_dim = embedding.shape[1]
|
||||||
|
return embedding_dim
|
||||||
|
|
||||||
|
|
||||||
# function test
|
# function test
|
||||||
async def test_funcs():
|
async def test_funcs():
|
||||||
result = await llm_model_func("How are you?")
|
result = await llm_model_func("How are you?")
|
||||||
@@ -43,37 +50,46 @@ async def test_funcs():
|
|||||||
print("embedding_func: ", result)
|
print("embedding_func: ", result)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(test_funcs())
|
# asyncio.run(test_funcs())
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
try:
|
||||||
|
embedding_dimension = await get_embedding_dim()
|
||||||
|
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
rag = LightRAG(
|
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||||
working_dir=WORKING_DIR,
|
rag.insert(f.read())
|
||||||
llm_model_func=llm_model_func,
|
|
||||||
embedding_func=EmbeddingFunc(
|
|
||||||
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Perform naive search
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
# Perform local search
|
||||||
rag.insert(f.read())
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform naive search
|
# Perform global search
|
||||||
print(
|
print(
|
||||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform local search
|
# Perform hybrid search
|
||||||
print(
|
print(
|
||||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
# Perform global search
|
if __name__ == "__main__":
|
||||||
print(
|
asyncio.run(main())
|
||||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform hybrid search
|
|
||||||
print(
|
|
||||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
|
||||||
)
|
|
Reference in New Issue
Block a user