284 lines
9.3 KiB
Python
284 lines
9.3 KiB
Python
import asyncio
|
|
import io
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import chromadb
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import soxr
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
from fastapi.responses import FileResponse, HTMLResponse, Response
|
|
from pydub import AudioSegment
|
|
from sentence_transformers import SentenceTransformer
|
|
from transformers import pipeline
|
|
|
|
# --- Model selection (each can be overridden through the environment) ---
ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-base")
TEXT_EMBEDDING_MODEL = os.getenv(
    "TEXT_EMBEDDING_MODEL",
    "sentence-transformers/all-MiniLM-L6-v2",
)

# --- Audio processing parameters ---
# Length, in seconds, of each indexed segment of a submitted clip.
SEGMENT_LENGTH_SEC = 60
# A trailing remainder shorter than this many seconds is not indexed.
MIN_TAIL_SEGMENT_SEC = 3.0
# Longest query clip, in seconds, accepted by the search endpoint.
SEARCH_MAX_SEC = 60.5
# Sample rate (Hz) that all decoded audio is converted to.
TARGET_SAMPLE_RATE = 16000

# --- Storage locations ---
AUDIO_DIR = Path("/app/audio")
CHROMA_DIR = Path("/app/chroma_data")
|
|
|
|
# Make sure the upload directory exists before any handler writes into it.
AUDIO_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="Audio Meaning DB")

# On-disk vector store; the collection is configured for cosine distance.
chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
collection = chroma_client.get_or_create_collection(
    name="audio_segments",
    metadata={"hnsw:space": "cosine"},
)

# Both models are loaded once, at import time, and shared by every request.
print(f"Loading ASR model {ASR_MODEL}...")
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    chunk_length_s=30,
    device="cpu",
)
print(f"ASR model {ASR_MODEL} ready.")

print(f"Loading text embedding model {TEXT_EMBEDDING_MODEL}...")
text_embedder = SentenceTransformer(TEXT_EMBEDDING_MODEL)
print(f"Text embedding model {TEXT_EMBEDDING_MODEL} ready.")
|
|
|
|
|
|
def _load_with_pydub(audio_bytes: bytes) -> tuple[np.ndarray, int]:
    """Decode *audio_bytes* with pydub and scale the samples by bit depth.

    Returns:
        ``(samples, frame_rate)``. ``samples`` is float32; when the source has
        more than one channel it is reshaped to ``(frames, channels)``,
        otherwise it stays one-dimensional.
    """
    segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
    pcm = np.array(segment.get_array_of_samples(), dtype=np.float32)

    channel_count = segment.channels
    if channel_count > 1:
        # Presumably the flat sample array is interleaved per frame, so one
        # row per frame separates the channels into columns.
        pcm = pcm.reshape(-1, channel_count)

    # 2 ** (bits - 1): the magnitude of the most negative signed sample.
    full_scale = float(2 ** (segment.sample_width * 8 - 1))
    return pcm / full_scale, segment.frame_rate
|
|
|
|
|
|
def load_audio(audio_bytes: bytes, filename: str) -> np.ndarray:
    """Decode audio bytes to mono float32 at TARGET_SAMPLE_RATE.

    Args:
        audio_bytes: Raw contents of the audio file.
        filename: Name whose extension selects the decoder. wav/flac/ogg/oga
            are first tried with soundfile; everything else, and any
            soundfile failure, goes through pydub.

    Returns:
        One-dimensional float32 samples at ``TARGET_SAMPLE_RATE``.
    """
    suffix = Path(filename).suffix.lower().lstrip(".")

    decoded = None
    if suffix in ("wav", "flac", "ogg", "oga"):
        try:
            decoded = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=False)
        except Exception:
            # Best effort: a failed soundfile decode falls through to pydub.
            decoded = None
    if decoded is None:
        decoded = _load_with_pydub(audio_bytes)
    samples, rate = decoded

    # Down-mix multi-channel audio by averaging across channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    samples = np.asarray(samples, dtype=np.float32)

    if rate == TARGET_SAMPLE_RATE:
        return samples
    return soxr.resample(samples, rate, TARGET_SAMPLE_RATE, quality="HQ").astype(np.float32)
|
|
|
|
|
|
def segment_audio(audio: np.ndarray) -> list[tuple[float, float, np.ndarray]]:
    """Split into ~60s segments. Returns [(start_sec, end_sec, samples), ...].

    Audio no longer than ``SEGMENT_LENGTH_SEC`` is returned as one segment.
    Longer audio is cut into full-length windows; the remainder after the
    last full window is kept only if it lasts at least
    ``MIN_TAIL_SEGMENT_SEC`` seconds, otherwise it is dropped.
    """
    rate = TARGET_SAMPLE_RATE
    duration = len(audio) / rate
    if duration <= SEGMENT_LENGTH_SEC:
        return [(0.0, duration, audio)]

    window = SEGMENT_LENGTH_SEC * rate
    full_windows = len(audio) // window
    remainder_start = full_windows * window

    pieces: list[tuple[float, float, np.ndarray]] = [
        (lo / rate, (lo + window) / rate, audio[lo:lo + window])
        for lo in range(0, remainder_start, window)
    ]

    remainder = audio[remainder_start:]
    if len(remainder) / rate >= MIN_TAIL_SEGMENT_SEC:
        pieces.append((remainder_start / rate, duration, remainder))

    return pieces
|
|
|
|
|
|
def transcribe(audio: np.ndarray) -> str:
    """Transcribe mono 16kHz float32 audio to text.

    Returns the recognized text with surrounding whitespace removed, or an
    empty string when the pipeline result has no (or empty) ``"text"``.
    """
    payload = {"array": audio, "sampling_rate": TARGET_SAMPLE_RATE}
    text = asr_pipeline(payload).get("text")
    if not text:
        return ""
    return text.strip()
|
|
|
|
|
|
def embed_text(text: str) -> list[float]:
    """Return the normalized sentence embedding of *text* as a plain list.

    An empty *text* is replaced by the placeholder "(no speech detected)"
    before encoding.
    """
    sentence = text or "(no speech detected)"
    vector = text_embedder.encode(
        sentence,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vector.tolist()
|
|
|
|
|
|
def fmt_time(sec: float) -> str:
    """Format a duration in seconds as ``M:SS``.

    The value is rounded to the nearest whole second first. Negative values
    are rendered as the formatted magnitude with a leading ``-`` (for example
    ``-5`` gives ``"-0:05"``); applying floor division and modulo directly to
    a negative number would produce ``"-1:55"`` instead.

    Args:
        sec: Duration in seconds.

    Returns:
        Minutes (not zero-padded) and zero-padded seconds joined by ``:``.
    """
    total = int(round(sec))
    sign = "-" if total < 0 else ""
    minutes, seconds = divmod(abs(total), 60)
    return f"{sign}{minutes}:{seconds:02d}"
|
|
|
|
|
|
@app.post("/api/submit")
async def submit_audio(
    file: UploadFile = File(...),
    description: str = Form(""),
):
    """Store an uploaded clip and index each of its segments by transcript.

    The upload is decoded, split with ``segment_audio``, and every segment is
    transcribed and embedded. The original bytes are saved under a
    UUID-based file name in ``AUDIO_DIR`` and one record per segment is added
    to the vector collection.

    Args:
        file: The uploaded audio file.
        description: Free-form text stored with every segment's metadata.

    Returns:
        A dict with the generated ``parent_id``, the stored ``filename``, the
        number of segments added, the collection's total segment count, the
        clip duration in seconds, and a per-segment summary list.

    Raises:
        HTTPException: 400 when the upload cannot be decoded or decodes to
            zero samples.
    """
    audio_bytes = await file.read()

    # Decoding and resampling are CPU-bound; keep them off the event loop,
    # the same way transcription and embedding are handled below.
    try:
        audio = await asyncio.to_thread(load_audio, audio_bytes, file.filename or "")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not decode audio: {e}") from e
    if len(audio) == 0:
        # Nothing to transcribe; reject instead of feeding the ASR model an
        # empty array.
        raise HTTPException(status_code=400, detail="Could not decode audio: no samples found")

    parent_id = str(uuid.uuid4())
    ext = Path(file.filename or "audio.wav").suffix or ".wav"
    filename = f"{parent_id}{ext}"
    filepath = AUDIO_DIR / filename

    total_sec = len(audio) / TARGET_SAMPLE_RATE
    segments = segment_audio(audio)
    is_full_clip = len(segments) == 1

    ids: list[str] = []
    embeddings: list[list[float]] = []
    metadatas: list[dict] = []
    segment_summaries: list[dict] = []

    try:
        await asyncio.to_thread(filepath.write_bytes, audio_bytes)

        for i, (start, end, samples) in enumerate(segments):
            transcript = await asyncio.to_thread(transcribe, samples)
            embedding = await asyncio.to_thread(embed_text, transcript)

            ids.append(f"{parent_id}:{i}")
            embeddings.append(embedding)
            metadatas.append({
                "parent_id": parent_id,
                "parent_filename": filename,
                "parent_original_name": file.filename or "unknown",
                "segment_index": i,
                "start_sec": float(start),
                "end_sec": float(end),
                "parent_duration_sec": float(total_sec),
                "description": description,
                "transcript": transcript,
                "is_full_clip": is_full_clip,
            })
            segment_summaries.append({
                "segment_index": i,
                "start_sec": float(start),
                "end_sec": float(end),
                "transcript": transcript,
            })

        collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
    except Exception:
        # Do not leave a stored file behind that no index record refers to.
        filepath.unlink(missing_ok=True)
        raise

    return {
        "parent_id": parent_id,
        "filename": filename,
        "segments_added": len(segments),
        "total_segments": collection.count(),
        "duration_sec": total_sec,
        "segments": segment_summaries,
    }
|
|
|
|
|
|
@app.post("/api/search")
async def search_audio(file: UploadFile = File(...), n: int = 10):
    """Find indexed segments whose transcript is closest to the query clip's.

    The query audio is decoded, transcribed and embedded, then compared with
    the stored segment embeddings. ``similarity`` in each result is
    ``1 - distance`` as reported by the collection, rounded to 4 decimals.

    Args:
        file: The query audio clip (at most ``SEARCH_MAX_SEC`` seconds).
        n: Maximum number of results to return; must be at least 1.

    Returns:
        ``{"results": [...], "query_transcript": str}``; when the collection
        is empty the result list is empty and a ``message`` key is included.

    Raises:
        HTTPException: 400 when ``n`` is below 1, or the audio cannot be
            decoded, contains no samples, or is too long.
    """
    if n < 1:
        # A non-positive result count is not a meaningful query; reject it
        # here rather than letting the vector store fail on it.
        raise HTTPException(status_code=400, detail="n must be at least 1.")

    audio_bytes = await file.read()
    # Decoding is CPU-bound; run it off the event loop like the model calls.
    try:
        audio = await asyncio.to_thread(load_audio, audio_bytes, file.filename or "")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not decode audio: {e}") from e
    if len(audio) == 0:
        raise HTTPException(status_code=400, detail="Could not decode audio: no samples found")

    duration_sec = len(audio) / TARGET_SAMPLE_RATE
    if duration_sec > SEARCH_MAX_SEC:
        raise HTTPException(
            status_code=400,
            detail=f"Search audio must be at most {SEGMENT_LENGTH_SEC}s. Got {duration_sec:.1f}s.",
        )

    count = collection.count()
    if count == 0:
        return {"results": [], "message": "No audio in database yet.", "query_transcript": ""}

    transcript = await asyncio.to_thread(transcribe, audio)
    embedding = await asyncio.to_thread(embed_text, transcript)

    results = collection.query(
        query_embeddings=[embedding],
        n_results=min(n, count),
        include=["distances", "metadatas"],
    )

    # Only one query embedding was sent, so the first row holds its hits.
    hits = zip(results["ids"][0], results["distances"][0], results["metadatas"][0])
    matches = [
        {
            "id": doc_id,
            "parent_filename": metadata["parent_filename"],
            "parent_original_name": metadata["parent_original_name"],
            "segment_index": metadata["segment_index"],
            "start_sec": metadata["start_sec"],
            "end_sec": metadata["end_sec"],
            "parent_duration_sec": metadata["parent_duration_sec"],
            "description": metadata.get("description", ""),
            "transcript": metadata.get("transcript", ""),
            "is_full_clip": metadata.get("is_full_clip", True),
            "similarity": round(1 - distance, 4),
        }
        for doc_id, distance, metadata in hits
    ]

    return {"results": matches, "query_transcript": transcript}
|
|
|
|
|
|
@app.get("/api/audio/{filename}")
async def get_audio(filename: str):
    """Serve a stored upload from ``AUDIO_DIR`` by its file name.

    Args:
        filename: Bare file name of the stored upload.

    Raises:
        HTTPException: 404 when ``filename`` is not a plain file name or no
            regular file with that name exists in ``AUDIO_DIR``.
    """
    # Accept only a bare name: anything containing a path separator (or an
    # absolute path) would otherwise be joined onto, or replace, AUDIO_DIR.
    if Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="Audio not found")

    filepath = AUDIO_DIR / filename
    # is_file() rather than exists(): a directory such as ".." exists but
    # cannot be served as a file.
    if not filepath.is_file():
        raise HTTPException(status_code=404, detail="Audio not found")
    return FileResponse(filepath)
|
|
|
|
|
|
@app.get("/api/segment/{parent_filename}")
async def get_segment(parent_filename: str, start: float, end: float):
    """Return the ``[start, end)`` seconds of a stored upload as 16-bit WAV.

    The stored file is decoded with ``load_audio`` (mono, at
    ``TARGET_SAMPLE_RATE``), sliced, and re-encoded in memory. ``end`` is
    clamped to the length of the decoded audio.

    Args:
        parent_filename: Bare file name of the stored upload.
        start: Segment start in seconds; must be >= 0.
        end: Segment end in seconds; must be greater than ``start``.

    Raises:
        HTTPException: 404 when the file name is not a plain name or no such
            file exists; 400 when the range is invalid (including NaN) or
            selects no samples.
    """
    # Accept only a bare name so the lookup cannot leave AUDIO_DIR, and
    # require a regular file (a directory such as ".." also "exists").
    if Path(parent_filename).name != parent_filename:
        raise HTTPException(status_code=404, detail="Audio not found")
    filepath = AUDIO_DIR / parent_filename
    if not filepath.is_file():
        raise HTTPException(status_code=404, detail="Audio not found")

    # Written as a chained comparison so that NaN (for which every
    # comparison is false) is rejected as well.
    if not (0 <= start < end):
        raise HTTPException(status_code=400, detail="Invalid segment range")

    # File read and decode are blocking; keep them off the event loop.
    audio_bytes = await asyncio.to_thread(filepath.read_bytes)
    audio = await asyncio.to_thread(load_audio, audio_bytes, parent_filename)

    sr = TARGET_SAMPLE_RATE
    # Clamp before converting to int so an infinite bound cannot overflow.
    start_idx = int(min(start * sr, len(audio)))
    end_idx = int(min(end * sr, len(audio)))
    if end_idx <= start_idx:
        raise HTTPException(status_code=400, detail="Empty segment slice")

    buf = io.BytesIO()
    sf.write(buf, audio[start_idx:end_idx], sr, format="WAV", subtype="PCM_16")
    return Response(content=buf.getvalue(), media_type="audio/wav")
|
|
|
|
|
|
@app.get("/api/stats")
async def stats():
    """Report how many segments and how many distinct clips are indexed.

    Returns:
        ``{"total_segments": int, "total_clips": int}`` where ``total_clips``
        counts the distinct ``parent_id`` values across all stored metadata.
    """
    segment_count = collection.count()
    if segment_count > 0:
        records = collection.get(include=["metadatas"])
        parent_ids = {meta["parent_id"] for meta in records["metadatas"]}
        clip_count = len(parent_ids)
    else:
        clip_count = 0
    return {"total_segments": segment_count, "total_clips": clip_count}
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page from ``static/index.html``.

    The path is relative to the process's working directory. The file is
    read as UTF-8 explicitly so the result does not depend on the locale's
    default encoding.
    """
    return Path("static/index.html").read_text(encoding="utf-8")
|