example-projects/examples/image_meaning_db/backend/main.py

198 lines
5.9 KiB
Python

import asyncio
import io
import os
import uuid
from pathlib import Path

import chromadb
import httpx
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from sentence_transformers import SentenceTransformer
# Runtime configuration. Every setting can be overridden through the
# environment; the defaults match the container layout under /app.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "clip-ViT-B-32")
SEED_COUNT = int(os.environ.get("SEED_COUNT", "100"))
IMAGE_DIR = Path(os.environ.get("IMAGE_DIR", "/app/images"))
CHROMA_DIR = Path(os.environ.get("CHROMA_DIR", "/app/chroma_data"))

# Uploaded and seeded image files are written here and served by /api/images.
IMAGE_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="Image Meaning DB")

# Persistent vector store. The collection is configured for cosine distance,
# which the search endpoint turns into a similarity score via `1 - distance`.
chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
collection = chroma_client.get_or_create_collection(
    name="images",
    metadata={"hnsw:space": "cosine"},
)

# The embedding model is loaded once at import time and shared by every request.
print(f"Loading embedding model {EMBEDDING_MODEL}...")
embedder = SentenceTransformer(EMBEDDING_MODEL)
print(f"Model {EMBEDDING_MODEL} ready.")
def load_image(image_bytes: bytes, max_dim: int = 512) -> Image.Image:
    """Decode *image_bytes* into an RGB image no larger than *max_dim* per side.

    Args:
        image_bytes: Raw encoded image data in any format Pillow can read.
        max_dim: Maximum width/height of the result. Larger images are
            downscaled with their aspect ratio preserved; smaller images are
            returned at their original size.

    Returns:
        A fully decoded RGB ``PIL.Image.Image`` that no longer references the
        input buffer.

    Raises:
        PIL.UnidentifiedImageError: If the bytes are not a recognised image.
        OSError: For other decode failures such as truncated data.
    """
    # convert() yields a fully loaded copy, so the lazily-decoded source can be
    # closed immediately instead of lingering until garbage collection.
    with Image.open(io.BytesIO(image_bytes)) as src:
        img = src.convert("RGB")
    if max(img.size) > max_dim:
        # thumbnail() resizes in place and keeps the aspect ratio.
        img.thumbnail((max_dim, max_dim), Image.LANCZOS)
    return img
def get_embedding(img: Image.Image) -> list[float]:
    """Embed *img* with the shared model and return the vector as a list.

    The vector is requested with ``normalize_embeddings=True`` and converted
    from NumPy to plain Python floats so it can be passed straight to Chroma.
    """
    return embedder.encode(
        img,
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).tolist()
# NOTE(review): picsum.photos is assumed to cap ``limit`` at 100 items per
# page -- confirm against the current API. Larger SEED_COUNT values are
# fetched across several pages.
_PICSUM_PAGE_LIMIT = 100


async def _list_seed_photos(client: httpx.AsyncClient, count: int) -> list:
    """Return up to *count* photo records from the picsum.photos list endpoint."""
    # A fixed page size keeps page offsets consistent across requests; when
    # count <= _PICSUM_PAGE_LIMIT this is a single request for page 1.
    page_size = min(count, _PICSUM_PAGE_LIMIT)
    photos: list = []
    page = 1
    while len(photos) < count:
        resp = await client.get(
            "https://picsum.photos/v2/list",
            params={"page": page, "limit": page_size},
        )
        resp.raise_for_status()
        batch = resp.json()
        photos.extend(batch)
        if len(batch) < page_size:
            break  # a short (or empty) page means the listing is exhausted
        page += 1
    return photos[:count]


async def _download_seed_images(client: httpx.AsyncClient, photos: list) -> list:
    """Download a 512x512 rendition of every photo, at most 8 at a time.

    Returns one entry per photo: ``(photo, image_bytes)`` on success, or
    ``None`` when that download failed (the failure is printed).
    """
    sem = asyncio.Semaphore(8)

    async def fetch(p):
        async with sem:
            try:
                r = await client.get(f"https://picsum.photos/id/{p['id']}/512/512")
                r.raise_for_status()
                return p, r.content
            except Exception as e:
                print(f"Seed: fetch failed id={p['id']}: {e}")
                return None

    return await asyncio.gather(*(fetch(p) for p in photos))


def _embed_seed_image(image_bytes: bytes) -> list[float]:
    """Decode and embed one image; CPU-bound, so callers run it in a thread."""
    return get_embedding(load_image(image_bytes))


async def _seed_database():
    """Fill an empty collection with sample photos from picsum.photos.

    Does nothing when SEED_COUNT is not positive or the collection already
    holds images. Failures while listing, downloading, decoding or storing
    photos are printed and skipped rather than raised. Each seeded photo is
    saved as ``IMAGE_DIR/seed-<id>.jpg`` and indexed under id ``seed-<id>``.
    """
    if SEED_COUNT <= 0:
        print(f"Seed: SEED_COUNT is {SEED_COUNT}, skipping.")
        return
    existing = collection.count()
    if existing > 0:
        print(f"Seed: collection already has {existing} images, skipping.")
        return
    print(f"Seed: fetching {SEED_COUNT} sample images from picsum.photos...")
    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            photos = await _list_seed_photos(client, SEED_COUNT)
            fetched = await _download_seed_images(client, photos)
    except Exception as e:
        print(f"Seed: list fetch failed: {e}")
        return
    added = 0
    for item in fetched:
        if item is None:
            continue
        p, image_bytes = item
        try:
            # Decode and embed together in a worker thread so neither PIL nor
            # the model blocks the event loop serving other requests.
            embedding = await asyncio.to_thread(_embed_seed_image, image_bytes)
            image_id = f"seed-{p['id']}"
            filename = f"{image_id}.jpg"
            (IMAGE_DIR / filename).write_bytes(image_bytes)
            collection.add(
                ids=[image_id],
                embeddings=[embedding],
                metadatas=[{
                    "filename": filename,
                    "original_name": f"picsum-{p['id']}.jpg",
                    "description": f"Photo by {p['author']}",
                }],
            )
            added += 1
            if added % 10 == 0:
                print(f"Seed: {added} images indexed...")
        except Exception as e:
            print(f"Seed: embed/store failed id={p['id']}: {e}")
    print(f"Seed: finished, {added} images added.")
def _report_seed_failure(task: asyncio.Task) -> None:
    """Print any exception that escaped the background seed task."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        print(f"Seed: task crashed: {exc!r}")


@app.on_event("startup")
async def on_startup():
    """Start seeding the database in the background without delaying startup.

    The event loop keeps only a weak reference to tasks, so the task is
    stored on ``app.state`` to prevent it being garbage-collected mid-run.
    A done-callback surfaces any exception the task did not handle itself.
    """
    task = asyncio.create_task(_seed_database())
    app.state.seed_task = task
    task.add_done_callback(_report_seed_failure)
# Extensions an upload may keep. FileResponse derives the served Content-Type
# from the stored file's extension, so a client-chosen ".html" or ".svg" must
# never survive (the server would render it as active content).
_ALLOWED_IMAGE_EXTS = frozenset(
    {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tif", ".tiff"}
)


@app.post("/api/submit")
async def submit_image(
    file: UploadFile = File(...),
    description: str = Form(""),
):
    """Upload an image, embed it, and store it.

    The original bytes are written to ``IMAGE_DIR/<uuid><ext>`` and the
    embedding is added to the collection with filename, original name and
    description as metadata.

    Returns:
        ``{"id", "filename", "total_images"}`` for the newly stored image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        # Decoding is CPU-bound; run it off the event loop.
        img = await asyncio.to_thread(load_image, image_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc
    embedding = await asyncio.to_thread(get_embedding, img)
    image_id = str(uuid.uuid4())
    suffix = Path(file.filename or "image.png").suffix.lower()
    ext = suffix if suffix in _ALLOWED_IMAGE_EXTS else ".png"
    filename = f"{image_id}{ext}"
    filepath = IMAGE_DIR / filename
    filepath.write_bytes(image_bytes)
    collection.add(
        ids=[image_id],
        embeddings=[embedding],
        metadatas=[{
            "filename": filename,
            "original_name": file.filename or "unknown",
            "description": description,
        }],
    )
    count = collection.count()
    return {"id": image_id, "filename": filename, "total_images": count}
@app.post("/api/search")
async def search_image(file: UploadFile = File(...), n: int = 10):
    """Upload an image, embed it, and find nearest neighbors.

    Args:
        file: The query image.
        n: Maximum number of matches. The value is clamped to the range
            [1, number of stored images], so non-positive values return one match.

    Returns:
        ``{"results": [...]}``, where each match has id, filename,
        original_name, description and a cosine similarity rounded to 4
        places. If the database is empty, an empty result list plus a message
        is returned instead.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    count = collection.count()
    if count == 0:
        # Nothing to compare against, so skip the decode and embed work.
        return {"results": [], "message": "No images in database yet."}
    try:
        img = await asyncio.to_thread(load_image, image_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc
    embedding = await asyncio.to_thread(get_embedding, img)
    results = collection.query(
        query_embeddings=[embedding],
        # Chroma rejects n_results < 1, so clamp instead of failing with a 500.
        n_results=max(1, min(n, count)),
        include=["distances", "metadatas"],
    )
    matches = []
    for doc_id, distance, metadata in zip(
        results["ids"][0], results["distances"][0], results["metadatas"][0]
    ):
        metadata = metadata or {}
        matches.append({
            "id": doc_id,
            "filename": metadata.get("filename", ""),
            "original_name": metadata.get("original_name", ""),
            "description": metadata.get("description", ""),
            "similarity": round(1 - distance, 4),  # cosine distance -> similarity
        })
    return {"results": matches}
@app.get("/api/images/{filename}")
async def get_image(filename: str):
    """Serve a stored image.

    Only regular files located directly inside IMAGE_DIR are served. Any
    other name (missing files, directories such as "..", or paths that
    resolve outside IMAGE_DIR) gets a 404 with ``{"error": "Image not found"}``.
    """
    image_root = IMAGE_DIR.resolve()
    filepath = (image_root / filename).resolve()
    if filepath.parent != image_root or not filepath.is_file():
        return JSONResponse({"error": "Image not found"}, status_code=404)
    return FileResponse(filepath)
@app.get("/api/stats")
async def stats():
    """Report how many images are currently indexed in the collection."""
    total = collection.count()
    return {"total_images": total}
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page frontend.

    The path is relative to the process working directory, so the server is
    expected to be started from the directory that contains ``static/``.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    return Path("static/index.html").read_text(encoding="utf-8")