198 lines
5.9 KiB
Python
198 lines
5.9 KiB
Python
|
|
import asyncio
|
||
|
|
import io
|
||
|
|
import os
|
||
|
|
import uuid
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import chromadb
|
||
|
|
import httpx
|
||
|
|
from fastapi import FastAPI, File, Form, UploadFile
|
||
|
|
from fastapi.responses import FileResponse, HTMLResponse
|
||
|
|
from fastapi.staticfiles import StaticFiles
|
||
|
|
from PIL import Image
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
|
||
|
|
# --- Runtime configuration --------------------------------------------------
# Name of the SentenceTransformer model to load; overridable via environment.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "clip-ViT-B-32")
# Number of sample images requested when seeding (<= 0 disables seeding).
SEED_COUNT = int(os.getenv("SEED_COUNT", "100"))
# Directory where image files are written and served from.
IMAGE_DIR = Path("/app/images")
# Directory handed to the persistent Chroma client.
CHROMA_DIR = Path("/app/chroma_data")

# Make sure the image directory exists before anything tries to write to it.
IMAGE_DIR.mkdir(exist_ok=True, parents=True)

app = FastAPI(title="Image Meaning DB")

# --- Vector store -----------------------------------------------------------
chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
# The collection is configured for cosine distance between embeddings.
collection = chroma_client.get_or_create_collection(
    name="images",
    metadata={"hnsw:space": "cosine"},
)

# --- Embedding model (loaded once, at import time) ---------------------------
print(f"Loading embedding model {EMBEDDING_MODEL}...")
embedder = SentenceTransformer(EMBEDDING_MODEL)
print(f"Model {EMBEDDING_MODEL} ready.")
|
||
|
|
|
||
|
|
|
||
|
|
def load_image(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes into an RGB image no larger than 512 px per side.

    Args:
        image_bytes: Encoded image data.

    Returns:
        The decoded image converted to RGB. If its longest side exceeds
        512 pixels it is shrunk in place with ``thumbnail`` using LANCZOS
        resampling; otherwise it is returned as decoded.
    """
    size_limit = 512
    buffer = io.BytesIO(image_bytes)
    picture = Image.open(buffer).convert("RGB")

    longest_side = max(picture.size)
    if longest_side <= size_limit:
        # Already within bounds; nothing to resize.
        return picture

    picture.thumbnail((size_limit, size_limit), Image.LANCZOS)
    return picture
|
||
|
|
|
||
|
|
|
||
|
|
def get_embedding(img: Image.Image) -> list[float]:
    """Encode an image with the module-level embedder and return a plain list.

    Args:
        img: The image to embed.

    Returns:
        The embedding as a list of floats. The encoder is called with
        ``normalize_embeddings=True`` and ``convert_to_numpy=True``; the
        resulting array is converted with ``tolist()``.
    """
    vector = embedder.encode(
        img,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    as_list = vector.tolist()
    return as_list
|
||
|
|
|
||
|
|
|
||
|
|
def _prepare_seed_image(photo, image_bytes):
    """Decode, embed, and save one seed image; return (image_id, filename, embedding).

    Everything here is blocking (image decode, model inference, disk write),
    so callers running on the event loop should execute it in a worker thread.
    """
    img = load_image(image_bytes)
    embedding = get_embedding(img)
    image_id = f"seed-{photo['id']}"
    filename = f"{image_id}.jpg"
    (IMAGE_DIR / filename).write_bytes(image_bytes)
    return image_id, filename, embedding


async def _seed_database():
    """Fill an empty collection with sample images from picsum.photos.

    Does nothing when ``SEED_COUNT`` is not positive or when the collection
    already contains entries. Otherwise it requests a photo list, downloads a
    512x512 rendition of each photo (at most 8 downloads in flight), then
    embeds, saves, and indexes each one under the id ``seed-<picsum id>``.

    Failures are reported with ``print`` and never raised: a failed list
    request aborts seeding, while a failed download or a failed embed/store
    only skips that image.
    """
    if SEED_COUNT <= 0:
        print("Seed: SEED_COUNT is 0, skipping.")
        return

    # Query the count once and reuse it for both the check and the message.
    existing = collection.count()
    if existing > 0:
        print(f"Seed: collection already has {existing} images, skipping.")
        return

    print(f"Seed: fetching {SEED_COUNT} sample images from picsum.photos...")
    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            # NOTE(review): the list endpoint may cap `limit` per page, in which
            # case a large SEED_COUNT yields fewer photos - confirm against the
            # picsum API before relying on large values.
            resp = await client.get(
                "https://picsum.photos/v2/list",
                params={"page": 1, "limit": SEED_COUNT},
            )
            resp.raise_for_status()
            photos = resp.json()

            # Bound the number of simultaneous image downloads.
            sem = asyncio.Semaphore(8)

            async def fetch(p):
                """Download one photo; return (photo, bytes) or None on failure."""
                async with sem:
                    try:
                        r = await client.get(f"https://picsum.photos/id/{p['id']}/512/512")
                        r.raise_for_status()
                        return p, r.content
                    except Exception as e:
                        print(f"Seed: fetch failed id={p['id']}: {e}")
                        return None

            fetched = await asyncio.gather(*(fetch(p) for p in photos))
    except Exception as e:
        print(f"Seed: list fetch failed: {e}")
        return

    added = 0
    for item in fetched:
        if item is None:
            continue
        p, image_bytes = item
        try:
            # Decode + embed + disk write happen off the event loop so the
            # server stays responsive while seeding runs in the background.
            image_id, filename, embedding = await asyncio.to_thread(
                _prepare_seed_image, p, image_bytes
            )
            collection.add(
                ids=[image_id],
                embeddings=[embedding],
                metadatas=[{
                    "filename": filename,
                    "original_name": f"picsum-{p['id']}.jpg",
                    "description": f"Photo by {p['author']}",
                }],
            )
            added += 1
            if added % 10 == 0:
                print(f"Seed: {added} images indexed...")
        except Exception as e:
            # Best-effort seeding: report and move on to the next image.
            print(f"Seed: embed/store failed id={p['id']}: {e}")

    print(f"Seed: finished, {added} images added.")
|
||
|
|
|
||
|
|
|
||
|
|
@app.on_event("startup")
async def on_startup():
    """Kick off database seeding in the background when the app starts.

    Seeding is scheduled as a task rather than awaited, so startup does not
    wait for it to finish.
    """
    # The event loop keeps only a weak reference to tasks; hold a strong one
    # on the application state so the seeding task cannot be garbage-collected
    # before it completes.
    app.state.seed_task = asyncio.create_task(_seed_database())
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/api/submit")
async def submit_image(
    file: UploadFile = File(...),
    description: str = Form(""),
):
    """Upload an image, embed it, and store it.

    Args:
        file: The uploaded image.
        description: Optional free-text description stored with the image.

    Returns:
        A dict with the new image ``id``, the stored ``filename``, and the
        collection size (``total_images``) after the insert.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Local import: HTTPException is not among the module-level imports.
    from fastapi import HTTPException

    image_bytes = await file.read()

    # Decode in a worker thread; a bad upload is a client error, not a 500.
    try:
        img = await asyncio.to_thread(load_image, image_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # Model inference is blocking; keep it off the event loop (the seeder
    # already does the same).
    embedding = await asyncio.to_thread(get_embedding, img)

    image_id = str(uuid.uuid4())
    # NOTE(review): the extension comes from the client-supplied filename and
    # later influences how the file is served - consider restricting it.
    ext = Path(file.filename or "image.png").suffix or ".png"
    filename = f"{image_id}{ext}"
    filepath = IMAGE_DIR / filename
    # Persist the original bytes (not the resized copy used for embedding).
    filepath.write_bytes(image_bytes)

    collection.add(
        ids=[image_id],
        embeddings=[embedding],
        metadatas=[{
            "filename": filename,
            "original_name": file.filename or "unknown",
            "description": description,
        }],
    )

    count = collection.count()
    return {"id": image_id, "filename": filename, "total_images": count}
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/api/search")
async def search_image(file: UploadFile = File(...), n: int = 10):
    """Upload an image, embed it, and find its nearest neighbors.

    Args:
        file: The query image.
        n: Maximum number of matches to return. Values below 1 are treated
            as 1, and the value is capped at the number of stored images.

    Returns:
        ``{"results": [...]}`` where each match carries ``id``, ``filename``,
        ``original_name``, ``description`` and ``similarity``. When the
        collection is empty, ``results`` is empty and a ``message`` is added.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Local import: HTTPException is not among the module-level imports.
    from fastapi import HTTPException

    image_bytes = await file.read()

    # Decode in a worker thread; a bad upload is a client error, not a 500.
    try:
        img = await asyncio.to_thread(load_image, image_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # Check for an empty collection before paying for model inference.
    count = collection.count()
    if count == 0:
        return {"results": [], "message": "No images in database yet."}

    embedding = await asyncio.to_thread(get_embedding, img)

    results = collection.query(
        query_embeddings=[embedding],
        # Clamp to [1, count] so a zero/negative `n` cannot reach the query.
        n_results=max(1, min(n, count)),
        include=["distances", "metadatas"],
    )

    # One query embedding was sent, so index 0 holds its result lists.
    matches = [
        {
            "id": doc_id,
            "filename": metadata["filename"],
            "original_name": metadata["original_name"],
            "description": metadata.get("description", ""),
            # cosine distance -> similarity
            "similarity": round(1 - distance, 4),
        }
        for doc_id, distance, metadata in zip(
            results["ids"][0],
            results["distances"][0],
            results["metadatas"][0],
        )
    ]

    return {"results": matches}
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/images/{filename}")
async def get_image(filename: str):
    """Serve a stored image.

    Args:
        filename: Name of a file inside ``IMAGE_DIR``.

    Returns:
        A ``FileResponse`` for the file, or a 404 JSON response with body
        ``{"error": "Image not found"}`` when no such regular file exists.
    """
    # Local import: JSONResponse is not among the module-level imports.
    from fastapi.responses import JSONResponse

    filepath = IMAGE_DIR / filename
    # Accept only a bare file name (no directory components) that refers to a
    # regular file; this also rejects directories such as "..".
    is_bare_name = Path(filename).name == filename
    if not is_bare_name or not filepath.is_file():
        # Same body as before, but with a proper 404 status instead of 200.
        return JSONResponse(status_code=404, content={"error": "Image not found"})
    return FileResponse(filepath)
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/stats")
async def stats():
    """Report how many entries the image collection currently holds."""
    total = collection.count()
    payload = {"total_images": total}
    return payload
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/", response_class=HTMLResponse)
async def index():
    """Return the front-end page read from ``static/index.html``.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    # Decode explicitly as UTF-8 rather than relying on the locale default,
    # which can differ between environments.
    return Path("static/index.html").read_text(encoding="utf-8")
|