example-projects/examples/audio_meaning_db/backend/main.py

284 lines
9.3 KiB
Python
Raw Normal View History

import asyncio
import io
import os
import uuid
from pathlib import Path
import chromadb
import numpy as np
import soundfile as sf
import soxr
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, Response
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Model names; each can be overridden through the environment variable of the
# same name. ASR_MODEL is passed to the transformers pipeline below and
# TEXT_EMBEDDING_MODEL to SentenceTransformer.
ASR_MODEL = os.environ.get("ASR_MODEL", "openai/whisper-base")
TEXT_EMBEDDING_MODEL = os.environ.get(
"TEXT_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
# Length, in seconds, of each full segment produced by segment_audio().
SEGMENT_LENGTH_SEC = 60
# A trailing remainder shorter than this many seconds is not indexed by
# segment_audio().
MIN_TAIL_SEGMENT_SEC = 3.0
# Longest query clip, in seconds, accepted by the search endpoint. It is
# slightly above SEGMENT_LENGTH_SEC, which is the value quoted in the error
# message.
SEARCH_MAX_SEC = 60.5
# Sample rate (Hz) that load_audio() resamples every clip to.
TARGET_SAMPLE_RATE = 16000
# Directory where submitted audio files are written and served from.
AUDIO_DIR = Path("/app/audio")
# Directory handed to the persistent Chroma client.
CHROMA_DIR = Path("/app/chroma_data")
# Everything below runs at import time: storage setup, app creation, and
# loading of both models.
AUDIO_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(title="Audio Meaning DB")
chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
# The search endpoint reports similarity as ``1 - distance``, which relies on
# the collection being configured for cosine space here.
collection = chroma_client.get_or_create_collection(
name="audio_segments",
metadata={"hnsw:space": "cosine"},
)
print(f"Loading ASR model {ASR_MODEL}...")
# Speech-to-text pipeline, pinned to CPU.
# NOTE(review): chunk_length_s=30 presumably enables chunked processing of
# inputs longer than 30 s; confirm against the transformers pipeline docs.
asr_pipeline = pipeline(
"automatic-speech-recognition",
model=ASR_MODEL,
chunk_length_s=30,
device="cpu",
)
print(f"ASR model {ASR_MODEL} ready.")
print(f"Loading text embedding model {TEXT_EMBEDDING_MODEL}...")
# Sentence embedding model used by embed_text() to vectorise transcripts.
text_embedder = SentenceTransformer(TEXT_EMBEDDING_MODEL)
print(f"Text embedding model {TEXT_EMBEDDING_MODEL} ready.")
def _load_with_pydub(audio_bytes: bytes) -> tuple[np.ndarray, int]:
    """Decode audio bytes through pydub.

    Returns:
        ``(samples, frame_rate)``. ``samples`` is float32, divided by
        ``2 ** (bits_per_sample - 1)``. Clips with more than one channel are
        returned as a 2-D array with one column per channel; single-channel
        clips stay 1-D.
    """
    segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
    pcm = np.array(segment.get_array_of_samples(), dtype=np.float32)

    channel_count = segment.channels
    if channel_count > 1:
        # Turn the flat sample array into one row per frame.
        pcm = pcm.reshape(-1, channel_count)

    full_scale = float(2 ** (segment.sample_width * 8 - 1))
    return pcm / full_scale, segment.frame_rate
def load_audio(audio_bytes: bytes, filename: str) -> np.ndarray:
    """Return the clip as a mono float32 array sampled at TARGET_SAMPLE_RATE.

    Args:
        audio_bytes: Raw contents of the audio file.
        filename: Name of the file; only its extension is used, to pick the
            decoder.

    Files with a wav/flac/ogg/oga extension are first tried with soundfile
    and fall back to pydub if that fails; every other extension goes straight
    to pydub. Errors raised by the pydub decoder propagate to the caller.
    """
    suffix = Path(filename).suffix.lower().lstrip(".")

    if suffix not in ("wav", "flac", "ogg", "oga"):
        samples, rate = _load_with_pydub(audio_bytes)
    else:
        try:
            samples, rate = sf.read(
                io.BytesIO(audio_bytes), dtype="float32", always_2d=False
            )
        except Exception:
            # soundfile could not read it; let pydub have a go instead.
            samples, rate = _load_with_pydub(audio_bytes)

    # Down-mix multi-channel audio by averaging across channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    mono = np.asarray(samples, dtype=np.float32)

    if rate == TARGET_SAMPLE_RATE:
        return mono
    resampled = soxr.resample(mono, rate, TARGET_SAMPLE_RATE, quality="HQ")
    return resampled.astype(np.float32)
def segment_audio(audio: np.ndarray) -> list[tuple[float, float, np.ndarray]]:
    """Cut a clip into consecutive SEGMENT_LENGTH_SEC-second windows.

    Returns a list of ``(start_sec, end_sec, samples)`` tuples. A clip no
    longer than one window is returned whole as a single entry. For longer
    clips the remainder after the last full window is kept only when it lasts
    at least MIN_TAIL_SEGMENT_SEC seconds; a shorter remainder is dropped.
    """
    rate = TARGET_SAMPLE_RATE
    n_samples = len(audio)
    duration = n_samples / rate

    if duration <= SEGMENT_LENGTH_SEC:
        return [(0.0, duration, audio)]

    window = SEGMENT_LENGTH_SEC * rate
    # Sample index where the last full window ends.
    full_end = (n_samples // window) * window

    pieces: list[tuple[float, float, np.ndarray]] = [
        (lo / rate, (lo + window) / rate, audio[lo:lo + window])
        for lo in range(0, full_end, window)
    ]

    remainder = audio[full_end:]
    if len(remainder) / rate >= MIN_TAIL_SEGMENT_SEC:
        pieces.append((full_end / rate, duration, remainder))
    return pieces
def transcribe(audio: np.ndarray) -> str:
    """Run the ASR pipeline on a clip and return the stripped transcript.

    The samples are passed to the pipeline labelled with TARGET_SAMPLE_RATE.
    A missing or empty ``"text"`` field in the result yields ``""``.
    """
    asr_input = {"array": audio, "sampling_rate": TARGET_SAMPLE_RATE}
    text = asr_pipeline(asr_input).get("text")
    if not text:
        return ""
    return text.strip()
def embed_text(text: str) -> list[float]:
    """Embed a transcript with the sentence embedding model.

    An empty transcript is replaced by the placeholder
    ``"(no speech detected)"`` before encoding. The normalised embedding is
    returned as a plain list of floats.
    """
    if text:
        to_encode = text
    else:
        to_encode = "(no speech detected)"
    vector = text_embedder.encode(
        to_encode,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vector.tolist()
def fmt_time(sec: float) -> str:
    """Format a duration in seconds as ``M:SS``.

    The value is rounded to the nearest whole second first. Minutes are not
    capped, so an hour renders as ``60:00``. Negative durations are formatted
    from their magnitude with a leading ``-``; plain floor division would
    turn -5 seconds into ``-1:55``. A negative value that rounds to zero is
    shown as ``0:00`` with no sign.
    """
    whole_seconds = int(round(abs(sec)))
    minutes, seconds = divmod(whole_seconds, 60)
    sign = "-" if sec < 0 and whole_seconds else ""
    return f"{sign}{minutes}:{seconds:02d}"
@app.post("/api/submit")
async def submit_audio(
    file: UploadFile = File(...),
    description: str = Form(""),
):
    """Store an uploaded clip, transcribe it per segment and index the result.

    The upload is saved under AUDIO_DIR with a fresh UUID as its name, decoded
    to mono audio, split with segment_audio(), and each segment's transcript
    embedding is added to the Chroma collection with its metadata.

    Args:
        file: The uploaded audio file.
        description: Optional free-text description stored with every segment.

    Returns:
        A dict with the generated ``parent_id`` and ``filename``, the number
        of segments added, the collection's total segment count, the clip
        duration in seconds and a per-segment summary.

    Raises:
        HTTPException: 400 if the audio cannot be decoded or has no samples.
    """
    audio_bytes = await file.read()
    original_name = file.filename or ""
    parent_id = str(uuid.uuid4())
    ext = Path(original_name or "audio.wav").suffix or ".wav"
    filename = f"{parent_id}{ext}"
    filepath = AUDIO_DIR / filename
    filepath.write_bytes(audio_bytes)

    # From here on the stored file is removed unless indexing completes, so a
    # failed request never leaves an orphaned file behind.
    indexed = False
    try:
        try:
            # Decoding and resampling are CPU-bound; run them off the event
            # loop, as is already done for transcription.
            audio = await asyncio.to_thread(load_audio, audio_bytes, original_name)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Could not decode audio: {e}"
            ) from e
        if len(audio) == 0:
            # Nothing to transcribe; do not hand an empty array to the model.
            raise HTTPException(
                status_code=400, detail="Could not decode audio: no samples found"
            )

        total_sec = len(audio) / TARGET_SAMPLE_RATE
        segments = segment_audio(audio)
        is_full_clip = len(segments) == 1

        ids: list[str] = []
        embeddings: list[list[float]] = []
        metadatas: list[dict] = []
        segment_summaries: list[dict] = []
        for i, (start, end, samples) in enumerate(segments):
            transcript = await asyncio.to_thread(transcribe, samples)
            embedding = await asyncio.to_thread(embed_text, transcript)
            ids.append(f"{parent_id}:{i}")
            embeddings.append(embedding)
            metadatas.append({
                "parent_id": parent_id,
                "parent_filename": filename,
                "parent_original_name": file.filename or "unknown",
                "segment_index": i,
                "start_sec": float(start),
                "end_sec": float(end),
                "parent_duration_sec": float(total_sec),
                "description": description,
                "transcript": transcript,
                "is_full_clip": is_full_clip,
            })
            segment_summaries.append({
                "segment_index": i,
                "start_sec": float(start),
                "end_sec": float(end),
                "transcript": transcript,
            })

        collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
        indexed = True
    finally:
        if not indexed:
            filepath.unlink(missing_ok=True)

    return {
        "parent_id": parent_id,
        "filename": filename,
        "segments_added": len(segments),
        "total_segments": collection.count(),
        "duration_sec": total_sec,
        "segments": segment_summaries,
    }
@app.post("/api/search")
async def search_audio(file: UploadFile = File(...), n: int = 10):
    """Find indexed segments whose transcript is closest to the query clip's.

    The query clip is decoded, transcribed and embedded, then compared with
    the stored segment embeddings.

    Args:
        file: The uploaded query clip; at most SEARCH_MAX_SEC seconds long.
        n: Maximum number of matches to return; must be at least 1.

    Returns:
        A dict with ``results`` (matches carrying their stored metadata and a
        ``similarity`` computed as ``1 - distance``) and ``query_transcript``.
        When the collection is empty, ``results`` is empty and a ``message``
        is included.

    Raises:
        HTTPException: 400 if ``n`` is below 1, the audio cannot be decoded,
            has no samples, or is longer than SEARCH_MAX_SEC.
    """
    if n < 1:
        # A non-positive result count would otherwise be passed to the query.
        raise HTTPException(status_code=400, detail="n must be at least 1")

    audio_bytes = await file.read()
    try:
        # Decoding is CPU-bound; keep it off the event loop like transcription.
        audio = await asyncio.to_thread(load_audio, audio_bytes, file.filename or "")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Could not decode audio: {e}"
        ) from e
    if len(audio) == 0:
        raise HTTPException(
            status_code=400, detail="Could not decode audio: no samples found"
        )

    duration_sec = len(audio) / TARGET_SAMPLE_RATE
    if duration_sec > SEARCH_MAX_SEC:
        raise HTTPException(
            status_code=400,
            detail=f"Search audio must be at most {SEGMENT_LENGTH_SEC}s. Got {duration_sec:.1f}s.",
        )

    count = collection.count()
    if count == 0:
        return {"results": [], "message": "No audio in database yet.", "query_transcript": ""}

    transcript = await asyncio.to_thread(transcribe, audio)
    embedding = await asyncio.to_thread(embed_text, transcript)
    results = collection.query(
        query_embeddings=[embedding],
        # Never request more results than there are stored segments.
        n_results=min(n, count),
        include=["distances", "metadatas"],
    )

    # A single query embedding was sent, so only the first result row is used.
    matches = [
        {
            "id": doc_id,
            "parent_filename": metadata["parent_filename"],
            "parent_original_name": metadata["parent_original_name"],
            "segment_index": metadata["segment_index"],
            "start_sec": metadata["start_sec"],
            "end_sec": metadata["end_sec"],
            "parent_duration_sec": metadata["parent_duration_sec"],
            "description": metadata.get("description", ""),
            "transcript": metadata.get("transcript", ""),
            "is_full_clip": metadata.get("is_full_clip", True),
            "similarity": round(1 - distance, 4),
        }
        for doc_id, distance, metadata in zip(
            results["ids"][0], results["distances"][0], results["metadatas"][0]
        )
    ]
    return {"results": matches, "query_transcript": transcript}
@app.get("/api/audio/{filename}")
async def get_audio(filename: str):
    """Serve a stored audio file by the name it was saved under.

    Args:
        filename: Name of a file inside AUDIO_DIR.

    Raises:
        HTTPException: 404 if the name does not resolve to a regular file
            located inside AUDIO_DIR.
    """
    base_dir = AUDIO_DIR.resolve()
    filepath = (base_dir / filename).resolve()
    # Accept only regular files that live under AUDIO_DIR. A bare exists()
    # check would also accept directories (for example "..") and then fail
    # while sending the response.
    if base_dir not in filepath.parents or not filepath.is_file():
        raise HTTPException(status_code=404, detail="Audio not found")
    return FileResponse(filepath)
@app.get("/api/segment/{parent_filename}")
async def get_segment(parent_filename: str, start: float, end: float):
    """Return the ``[start, end)`` slice of a stored clip as 16-bit PCM WAV.

    The stored file is decoded with load_audio(), so the slice is mono at
    TARGET_SAMPLE_RATE. ``end`` is clamped to the end of the clip.

    Args:
        parent_filename: Name of a file inside AUDIO_DIR.
        start: Slice start in seconds; must be non-negative.
        end: Slice end in seconds; must be greater than ``start``.

    Raises:
        HTTPException: 404 if the name does not resolve to a regular file
            inside AUDIO_DIR; 400 if the range is invalid or selects no
            samples.
    """
    base_dir = AUDIO_DIR.resolve()
    filepath = (base_dir / parent_filename).resolve()
    # Accept only regular files located under AUDIO_DIR.
    if base_dir not in filepath.parents or not filepath.is_file():
        raise HTTPException(status_code=404, detail="Audio not found")

    # The chained comparison is false for NaN, so NaN bounds are rejected
    # here instead of crashing in int() further down.
    if not (0 <= start < end):
        raise HTTPException(status_code=400, detail="Invalid segment range")

    audio_bytes = filepath.read_bytes()
    # Decoding is CPU-bound; run it off the event loop.
    audio = await asyncio.to_thread(load_audio, audio_bytes, parent_filename)

    sr = TARGET_SAMPLE_RATE
    total = len(audio)
    start_pos = start * sr
    end_pos = end * sr
    # Clamp before converting to int so an infinite or very large bound
    # cannot raise OverflowError.
    start_idx = int(start_pos) if start_pos < total else total
    end_idx = int(end_pos) if end_pos < total else total
    if end_idx <= start_idx:
        raise HTTPException(status_code=400, detail="Empty segment slice")

    buf = io.BytesIO()
    sf.write(buf, audio[start_idx:end_idx], sr, format="WAV", subtype="PCM_16")
    return Response(content=buf.getvalue(), media_type="audio/wav")
@app.get("/api/stats")
async def stats():
    """Report how many segments are indexed and how many clips they come from.

    The clip count is the number of distinct ``parent_id`` values across all
    stored segment metadata. The metadata is only fetched when the collection
    is non-empty.
    """
    segment_count = collection.count()
    if segment_count <= 0:
        return {"total_segments": segment_count, "total_clips": 0}

    parent_ids = set()
    for meta in collection.get(include=["metadatas"])["metadatas"]:
        parent_ids.add(meta["parent_id"])
    return {"total_segments": segment_count, "total_clips": len(parent_ids)}
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page from ``static/index.html``.

    The path is relative, so it is resolved against the process's working
    directory. The file is read as UTF-8 explicitly so that the result does
    not depend on the locale's default encoding.
    """
    return Path("static/index.html").read_text(encoding="utf-8")