Fix created_at handling in Chroma vector db

This commit is contained in:
yangdx
2025-05-02 16:21:48 +08:00
parent 011659b8bc
commit 0dc712e0bc

View File

@@ -114,11 +114,17 @@ class ChromaVectorDBStorage(BaseVectorStorage):
return return
try: try:
import time
current_time = int(time.time())
ids = list(data.keys()) ids = list(data.keys())
documents = [v["content"] for v in data.values()] documents = [v["content"] for v in data.values()]
metadatas = [ metadatas = [
{k: v for k, v in item.items() if k in self.meta_fields} {
or {"_default": "true"} **{k: v for k, v in item.items() if k in self.meta_fields},
"created_at": current_time,
}
or {"_default": "true", "created_at": current_time}
for item in data.values() for item in data.values()
] ]
@@ -183,6 +189,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
"id": results["ids"][0][i], "id": results["ids"][0][i],
"distance": 1 - results["distances"][0][i], "distance": 1 - results["distances"][0][i],
"content": results["documents"][0][i], "content": results["documents"][0][i],
"created_at": results["metadatas"][0][i].get("created_at"),
**results["metadatas"][0][i], **results["metadatas"][0][i],
} }
for i in range(len(results["ids"][0])) for i in range(len(results["ids"][0]))
@@ -298,6 +305,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
"id": result["ids"][0], "id": result["ids"][0],
"vector": result["embeddings"][0], "vector": result["embeddings"][0],
"content": result["documents"][0], "content": result["documents"][0],
"created_at": result["metadatas"][0].get("created_at"),
**result["metadatas"][0], **result["metadatas"][0],
} }
except Exception as e: except Exception as e:
@@ -331,6 +339,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
"id": result["ids"][i], "id": result["ids"][i],
"vector": result["embeddings"][i], "vector": result["embeddings"][i],
"content": result["documents"][i], "content": result["documents"][i],
"created_at": result["metadatas"][i].get("created_at"),
**result["metadatas"][i], **result["metadatas"][i],
} }
for i in range(len(result["ids"])) for i in range(len(result["ids"]))