#!/usr/bin/env python3
"""
Fumoca Frame Extraction Engine

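Video-first, splat-oriented frame selection:
- samples candidate frames at a fixed rate from the start of the video
- scores them for sharpness, exposure and centre-weighted detail
- filters out frames that fail the blur / brightness quality gate
- enforces minimum time spacing and rejects near-duplicate frames
- writes the selected frames plus manifest.json and summary.json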

Dependencies:
pip install opencv-python numpy
"""

from __future__ import annotations

import argparse
import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

# Package version: shown by --version and embedded in manifest.json / summary.json.
__version__ = "10.1.0"

# Logger shared by every function in this module; main() configures output via logging.basicConfig.
log = logging.getLogger("fumoca")


@dataclass
class ExtractionProfile:
    """Tunable settings for one extraction preset (see PROFILES).

    The whole profile is serialized with asdict() into manifest.json, so the
    field declaration order below is also the order in that file.
    """

    name: str                         # preset key; also reported as "profile" in summary.json
    sample_fps: float                 # target candidate rate: every round(source_fps / sample_fps)-th frame is sampled
    max_candidates: int               # sampling stops once this many candidates are collected
    max_selected: int                 # main selection pass stops once this many frames are kept
    min_selected: int                 # if fewer frames are kept, a relaxed recovery pass tops the selection up
    min_gap_seconds: float            # minimum time between any two selected frames
    blur_threshold: float             # variance-of-Laplacian gate; also the reference value for normalize_blur
    brightness_min: float             # lower bound of the mean-gray gate and of the exposure range
    brightness_max: float             # upper bound of the mean-gray gate and of the exposure range
    duplicate_mse_threshold: float    # MSE (on 256x256 resizes) below this => treated as a near-duplicate
    feature_match_bias: bool          # when True, also run the ORB feature-distance duplicate check
    centering_bias: float             # NOTE(review): not read anywhere in this file — confirm whether still needed
    blur_weight: float                # weight of normalized sharpness in quality_score
    exposure_weight: float            # weight of exposure_score in quality_score
    centering_weight: float           # weight of centering_score in quality_score
    spacing_weight: float             # scale factor of the time-spacing diversity bonus (spacing_score)
    # Dedup / recovery tuning — previously hardcoded magic numbers
    feature_min_distance: float       # ORB mean Hamming distance below this => too similar (was 26.0)
    blur_normalize_saturation: float  # normalize_blur saturation divisor (was 2.2)
    recovery_gap_factor: float        # min_gap_seconds multiplier used by the recovery pass (was 0.75)


# Built-in presets keyed by their ``name`` — these keys are the values accepted by
# --profile. "fast" samples the fewest frames and keeps the fewest, "pro" samples
# most densely and keeps the most. In every preset the four score weights
# (blur / exposure / centering / spacing) sum to 1.0.
PROFILES: Dict[str, ExtractionProfile] = {
    preset.name: preset
    for preset in (
        ExtractionProfile(
            name="fast",
            sample_fps=2.5,
            max_candidates=90,
            max_selected=24,
            min_selected=16,
            min_gap_seconds=0.25,
            blur_threshold=80.0,
            brightness_min=40.0,
            brightness_max=215.0,
            duplicate_mse_threshold=18.0,
            feature_match_bias=False,
            centering_bias=0.15,
            blur_weight=0.45,
            exposure_weight=0.25,
            centering_weight=0.10,
            spacing_weight=0.20,
            feature_min_distance=26.0,
            blur_normalize_saturation=2.2,
            recovery_gap_factor=0.75,
        ),
        ExtractionProfile(
            name="standard",
            sample_fps=4.0,
            max_candidates=140,
            max_selected=42,
            min_selected=24,
            min_gap_seconds=0.20,
            blur_threshold=90.0,
            brightness_min=38.0,
            brightness_max=218.0,
            duplicate_mse_threshold=16.0,
            feature_match_bias=True,
            centering_bias=0.20,
            blur_weight=0.45,
            exposure_weight=0.20,
            centering_weight=0.15,
            spacing_weight=0.20,
            feature_min_distance=26.0,
            blur_normalize_saturation=2.2,
            recovery_gap_factor=0.75,
        ),
        ExtractionProfile(
            name="pro",
            sample_fps=6.0,
            max_candidates=220,
            max_selected=72,
            min_selected=36,
            min_gap_seconds=0.16,
            blur_threshold=95.0,
            brightness_min=35.0,
            brightness_max=220.0,
            duplicate_mse_threshold=14.0,
            feature_match_bias=True,
            centering_bias=0.25,
            blur_weight=0.42,
            exposure_weight=0.18,
            centering_weight=0.18,
            spacing_weight=0.22,
            feature_min_distance=24.0,
            blur_normalize_saturation=2.2,
            recovery_gap_factor=0.75,
        ),
    )
}


@dataclass
class FrameMetrics:
    """Measurements and selection diagnostics for one sampled candidate frame.

    Created by score_candidates; select_frames and write_outputs later fill in or
    overwrite some fields. Serialized with asdict() into manifest.json.
    """

    index: int                          # candidate position in sampling order (0-based)
    time_sec: float                     # source frame number / source fps
    filename: str                       # "candidate_NNNN.jpg" placeholder; write_outputs renames selected frames
    width: int                          # frame width in pixels as scored (after any rotation)
    height: int                         # frame height in pixels as scored (after any rotation)
    blur_score: float                   # variance of the grayscale Laplacian; higher = sharper
    brightness_mean: float              # mean grayscale intensity
    brightness_std: float               # grayscale standard deviation (contrast)
    exposure_score: float               # exposure fit, see exposure_score()
    centering_score: float              # centre-weighted edge density in [0, 1], see centering_score()
    keep_quality_gate: bool             # True if the frame passed the hard blur + brightness gate
    quality_score: float                # weighted blur/exposure/centering score clipped to [0, 1]
    spacing_score: float = 0.0          # diversity bonus — kept separate, not folded into quality_score
    final_score: float = 0.0            # quality_score + spacing_score (clipped); written by select_frames for reporting, candidates are ranked by quality_score
    duplicate_mse_from_prev_kept: Optional[float] = None      # MSE vs the nearest-in-time selected frame, if a comparison ran
    feature_distance_from_prev_kept: Optional[float] = None   # ORB distance vs that same frame, if computed and available


def ensure_dir(path: Path) -> None:
    """Create *path* including any missing parent directories; no-op if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def variance_of_laplacian(image_bgr: np.ndarray) -> float:
    """Return the variance of the grayscale Laplacian of a BGR frame.

    Used as the sharpness ("blur_score") measure: more high-frequency detail
    gives a larger value.
    """
    response = cv2.Laplacian(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY), cv2.CV_64F)
    return float(response.var())


def brightness_stats(image_bgr: np.ndarray) -> Tuple[float, float]:
    """Return ``(mean, std)`` of the frame's grayscale intensities."""
    luma = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    mean_level = float(np.mean(luma))
    spread = float(np.std(luma))
    return mean_level, spread


def exposure_score(mean_brightness: float, brightness_std: float, p: ExtractionProfile) -> float:
    """Score how well a frame is exposed.

    75% of the result is a triangular fit that is 1.0 when ``mean_brightness`` sits
    at the midpoint of ``[p.brightness_min, p.brightness_max]`` and falls linearly
    to 0.0 one half-range away (the half-range is floored at 1.0). The remaining
    25% rewards contrast and saturates once ``brightness_std`` reaches 64.
    """
    lo, hi = p.brightness_min, p.brightness_max
    midpoint = (lo + hi) / 2.0
    # Floor keeps a degenerate (zero-width) range from dividing by zero.
    tolerance = max((hi - lo) / 2.0, 1.0)
    fit = max(0.0, 1.0 - abs(mean_brightness - midpoint) / tolerance)
    contrast = min(brightness_std / 64.0, 1.0)
    return 0.75 * fit + 0.25 * contrast


def centering_score(image_bgr: np.ndarray) -> float:
    """Return a [0, 1] centre-weighted edge-density score for a BGR frame.

    Lightweight saliency proxy: Canny edges (thresholds 70/180) are weighted by a
    cone that is 1.0 at the image centre and falls linearly to 0.0 at normalized
    radius 1 (the midpoint of each border). The weighted mean is scaled by 8 and
    clipped to [0, 1].
    """
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 70, 180)
    h, w = gray.shape[:2]
    # np.ogrid yields an (h, 1) column and a (1, w) row that broadcast to (h, w)
    # in the arithmetic below. np.mgrid would allocate two full (h, w) integer
    # coordinate arrays per frame for the same values.
    yy, xx = np.ogrid[0:h, 0:w]
    cx, cy = w / 2.0, h / 2.0
    nx = (xx - cx) / max(cx, 1.0)
    ny = (yy - cy) / max(cy, 1.0)
    dist = np.sqrt(nx ** 2 + ny ** 2)
    weights = np.clip(1.0 - dist, 0.0, 1.0)
    score = (edges.astype(np.float32) / 255.0 * weights).mean()
    return float(np.clip(score * 8.0, 0.0, 1.0))


def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared pixel difference between *a* and *b*.

    Both images are first resized to 256x256, so frames of different sizes can be
    compared, and the difference is taken in float32.
    """
    side = (256, 256)
    residual = cv2.resize(a, side).astype(np.float32) - cv2.resize(b, side).astype(np.float32)
    return float(np.mean(np.square(residual)))


def feature_distance(a: np.ndarray, b: np.ndarray) -> Optional[float]:
    """Return the mean ORB Hamming distance of cross-checked matches between two BGR frames.

    Lower values mean the frames look more alike. Returns None when either frame
    yields no descriptors, when no cross-checked match exists, or when the
    comparison raises. Failures are best-effort: they are logged at DEBUG (with
    traceback) and the caller simply skips the feature-based duplicate test.
    """
    try:
        orb = cv2.ORB_create(600)
        gray_a = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY)
        gray_b = cv2.cvtColor(b, cv2.COLOR_BGR2GRAY)
        _, desa = orb.detectAndCompute(gray_a, None)
        _, desb = orb.detectAndCompute(gray_b, None)
        if desa is None or desb is None:
            return None
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(desa, desb)
    except Exception:
        # Deliberately best-effort, but never silent: keep the traceback for debugging.
        log.debug("ORB feature comparison failed; skipping feature-based dedup", exc_info=True)
        return None
    if not matches:
        return None
    # Lower mean distance indicates higher similarity.
    return float(np.mean([m.distance for m in matches]))


def normalize_blur(blur_score: float, threshold: float, saturation: float = 2.2) -> float:
    """Map a variance-of-Laplacian value onto [0, 1].

    The result is ``(blur_score / threshold) / saturation`` clipped to [0, 1]. It
    reaches 1.0 once ``blur_score`` is ``saturation`` times the gate threshold.
    Non-positive scores map to 0.0. ``threshold`` and ``saturation`` are floored
    at 1e-6 to avoid division by zero. score_candidates passes
    profile.blur_normalize_saturation as *saturation*.
    """
    if blur_score <= 0:
        return 0.0
    relative_to_gate = blur_score / max(threshold, 1e-6)
    scaled = relative_to_gate / max(saturation, 1e-6)
    return float(np.clip(scaled, 0.0, 1.0))


def exposure_brightness_proxy(exposure_score_value: float, p: ExtractionProfile) -> float:
    """Linearly map a value in [0, 1] onto ``[p.brightness_min, p.brightness_max]``.

    0.0 maps to ``brightness_min`` and 1.0 to ``brightness_max``; values outside
    [0, 1] extrapolate beyond the range.

    NOTE(review): this is NOT an inverse of exposure_score(). That score peaks at
    the midpoint of the range and mixes in a contrast term. score_candidates gates
    on the measured ``brightness_mean`` directly, and nothing in this file calls
    this helper. Confirm whether external callers still need it.
    """
    return p.brightness_min + exposure_score_value * (p.brightness_max - p.brightness_min)


def _apply_rotation(frame: np.ndarray, rotation_code: int) -> np.ndarray:
    """Return *frame* rotated clockwise by *rotation_code* degrees.

    Only 90, 180 and 270 trigger a rotation; any other value (including 0)
    returns the input frame unchanged.
    """
    cv_flag = {
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_COUNTERCLOCKWISE,
    }.get(rotation_code)
    # cv2.ROTATE_90_CLOCKWISE is 0, so test for None rather than truthiness.
    return frame if cv_flag is None else cv2.rotate(frame, cv_flag)


def _backend_autorotates(cap: cv2.VideoCapture) -> bool:
    """Return True when the capture backend already rotates frames per container metadata.

    Backends that support CAP_PROP_ORIENTATION_AUTO (e.g. FFmpeg) can apply the
    orientation themselves. Rotating those frames again would double-rotate
    portrait clips. A missing constant (older OpenCV) or an unsupported property
    (``get`` returns 0) counts as "not auto-rotated".
    """
    auto_prop = getattr(cv2, "CAP_PROP_ORIENTATION_AUTO", None)
    if auto_prop is None:
        return False
    return cap.get(auto_prop) > 0


def extract_candidates(
    video_path: Path,
    out_candidates_dir: Path,
    profile: ExtractionProfile,
) -> List[Tuple[int, float, np.ndarray]]:
    """Sample candidate frames from the start of *video_path*.

    Every ``round(source_fps / profile.sample_fps)``-th source frame (at least
    every frame) is kept, starting at frame 0. Sampling stops when
    ``profile.max_candidates`` frames are collected or the stream ends. Frames are
    rotated upright from orientation metadata unless the backend already did it.

    Args:
        video_path: Video file to read.
        out_candidates_dir: Directory that is created if missing. No frames are
            written to it here.
        profile: Extraction settings (sample_fps, max_candidates, min_selected).

    Returns:
        ``(candidate_index, time_sec, frame_bgr)`` tuples in stream order.

    Raises:
        RuntimeError: if OpenCV cannot open the video.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video: {video_path}")

    try:
        source_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        safe_fps = max(source_fps, 1e-6)
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration_sec = total_frames / safe_fps if total_frames > 0 else None

        # Portrait-mode rotation metadata (common on Android/iOS captures)
        rotation = int(cap.get(cv2.CAP_PROP_ORIENTATION_META) or 0)
        if rotation not in (0, 90, 180, 270):
            rotation = 0
        manual_rotation = rotation
        if rotation and _backend_autorotates(cap):
            log.debug("Backend already applies %d° rotation; skipping manual rotation", rotation)
            manual_rotation = 0

        frame_interval = max(int(round(source_fps / max(profile.sample_fps, 0.1))), 1)
        # Sampling stops after max_candidates frames, so only this leading span
        # of the clip is ever examined.
        sampled_window_sec = profile.max_candidates * frame_interval / safe_fps

        if duration_sec is not None:
            log.info("Video: %.1fs at %.2f fps (rotation=%d°)", duration_sec, source_fps, rotation)
            if duration_sec < 3.0:
                log.warning("Video is very short (%.1fs) — may yield fewer frames than min_selected=%d", duration_sec, profile.min_selected)
            if duration_sec > sampled_window_sec:
                log.warning("Video is long (%.1fs) — only first ~%.1fs will be sampled (max_candidates=%d)", duration_sec, sampled_window_sec, profile.max_candidates)

        ensure_dir(out_candidates_dir)
        candidates: List[Tuple[int, float, np.ndarray]] = []

        frame_idx = 0
        while len(candidates) < profile.max_candidates:
            # grab() advances the stream; retrieve() (the colour conversion + copy
            # that read() would always do) is only paid for frames we keep.
            if not cap.grab():
                break
            if frame_idx % frame_interval == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                if manual_rotation:
                    frame = _apply_rotation(frame, manual_rotation)
                time_sec = frame_idx / safe_fps
                # Defensive copy so no candidate can alias a backend-owned buffer.
                candidates.append((len(candidates), time_sec, frame.copy()))
            frame_idx += 1

        log.info("Extracted %d candidate frames", len(candidates))
        return candidates

    finally:
        cap.release()


def score_candidates(candidates: List[Tuple[int, float, np.ndarray]], profile: ExtractionProfile) -> List[FrameMetrics]:
    """Measure every candidate frame and build its FrameMetrics record, in input order.

    For each ``(index, time_sec, frame)`` tuple this records:

    - sharpness (variance of the Laplacian);
    - grayscale mean and standard deviation;
    - the exposure and centering scores;
    - the hard gate: ``blur_score >= blur_threshold`` and mean brightness
      within ``[brightness_min, brightness_max]``;
    - ``quality_score``: the weighted sum of normalized blur, exposure and
      centering, clipped to [0, 1].
    """

    def _measure(cand_idx: int, stamp: float, image: np.ndarray) -> FrameMetrics:
        rows, cols = image.shape[:2]
        sharpness = variance_of_laplacian(image)
        luma_mean, luma_std = brightness_stats(image)
        exposure_fit = exposure_score(luma_mean, luma_std, profile)
        centrality = centering_score(image)
        in_brightness_band = profile.brightness_min <= luma_mean <= profile.brightness_max
        gate_ok = sharpness >= profile.blur_threshold and in_brightness_band
        blended = (
            normalize_blur(sharpness, profile.blur_threshold, profile.blur_normalize_saturation) * profile.blur_weight
            + exposure_fit * profile.exposure_weight
            + centrality * profile.centering_weight
        )
        return FrameMetrics(
            index=cand_idx,
            time_sec=float(stamp),
            filename=f"candidate_{cand_idx:04d}.jpg",
            width=int(cols),
            height=int(rows),
            blur_score=float(sharpness),
            brightness_mean=float(luma_mean),
            brightness_std=float(luma_std),
            exposure_score=float(exposure_fit),
            centering_score=float(centrality),
            keep_quality_gate=bool(gate_ok),
            quality_score=float(np.clip(blended, 0.0, 1.0)),
        )

    metrics = [_measure(*candidate) for candidate in candidates]
    passed = sum(1 for record in metrics if record.keep_quality_gate)
    log.info("Scored %d candidates — %d passed quality gate", len(metrics), passed)
    return metrics


def _assign_spacing_scores(
    m: FrameMetrics,
    selected: List[FrameMetrics],
    profile: ExtractionProfile,
) -> None:
    """Store the diversity bonus (``spacing_score``) and ``final_score`` on *m*.

    The bonus grows linearly with the time gap to the nearest already-selected
    frame, saturates at ``4 * min_gap_seconds``, and is scaled by
    ``spacing_weight``. ``quality_score`` itself is never modified.
    """
    if selected:
        nearest_gap = min(abs(m.time_sec - s.time_sec) for s in selected)
    else:
        nearest_gap = profile.min_gap_seconds
    spacing_bonus = min(nearest_gap / max(profile.min_gap_seconds * 4.0, 1e-6), 1.0)
    m.spacing_score = float(np.clip(spacing_bonus * profile.spacing_weight, 0.0, 1.0))
    m.final_score = float(np.clip(m.quality_score + m.spacing_score, 0.0, 1.0))


def _is_near_duplicate(
    m: FrameMetrics,
    frame: np.ndarray,
    selected: List[FrameMetrics],
    frame_by_index: Dict[int, np.ndarray],
    profile: ExtractionProfile,
) -> bool:
    """Return True if *frame* is too similar to the selected frame nearest to it in time.

    Records the MSE (and, if enabled, the ORB distance) on *m* for the manifest.
    """
    if not selected:
        return False
    # Compare against the nearest selected frame in *time*. `selected` is in
    # quality order, so its last element is not necessarily the neighbour.
    nearest = min(selected, key=lambda s: abs(s.time_sec - m.time_sec))
    nearest_frame = frame_by_index[nearest.index]

    d_mse = mse(frame, nearest_frame)
    m.duplicate_mse_from_prev_kept = float(d_mse)
    if d_mse < profile.duplicate_mse_threshold:
        return True

    if profile.feature_match_bias:
        feat_dist = feature_distance(frame, nearest_frame)
        m.feature_distance_from_prev_kept = feat_dist
        if feat_dist is not None and feat_dist < profile.feature_min_distance:
            return True
    return False


def select_frames(
    candidates: List[Tuple[int, float, np.ndarray]],
    metrics: List[FrameMetrics],
    profile: ExtractionProfile,
) -> List[FrameMetrics]:
    """Pick a diverse, good-quality subset of candidates, returned in time order.

    1. Candidates are visited from highest to lowest ``quality_score``.
    2. Main pass: a frame is kept only if all of these hold:
       - it passed the quality gate;
       - it is at least ``min_gap_seconds`` from every kept frame;
       - it is not a near-duplicate (MSE / optional ORB distance) of the kept
         frame nearest to it in time.
       The pass stops at ``max_selected`` frames.
    3. Recovery pass: if fewer than ``min_selected`` frames were kept, the best
       remaining frames are added until ``min_selected`` is reached. Gate
       failures are allowed and duplicate checks are skipped. The gap is
       relaxed to ``min_gap_seconds * recovery_gap_factor``.

    Side effect: fills the diagnostic fields (duplicate/feature distances,
    spacing_score, final_score) on the FrameMetrics objects it evaluates.
    """
    by_index = {m.index: m for m in metrics}
    # O(1) frame lookup instead of scanning `candidates` for every comparison.
    frame_by_index = {idx: frame for idx, _, frame in candidates}
    sorted_candidates = sorted(candidates, key=lambda c: by_index[c[0]].quality_score, reverse=True)

    selected: List[FrameMetrics] = []
    selected_indices: set = set()

    for idx, _, frame in sorted_candidates:
        m = by_index[idx]

        # Main pass admits only gate-passing frames. Gate failures remain
        # available to the recovery pass if the minimum is not reached.
        if not m.keep_quality_gate:
            continue

        if any(abs(m.time_sec - s.time_sec) < profile.min_gap_seconds for s in selected):
            continue

        if _is_near_duplicate(m, frame, selected, frame_by_index, profile):
            continue

        _assign_spacing_scores(m, selected, profile)
        selected.append(m)
        selected_indices.add(m.index)
        if len(selected) >= profile.max_selected:
            break

    # Recovery pass if too few selected.
    if len(selected) < profile.min_selected:
        log.warning("Only %d frames selected (min=%d) — running recovery pass", len(selected), profile.min_selected)
        relaxed_gap = profile.min_gap_seconds * profile.recovery_gap_factor
        for idx, _, _ in sorted_candidates:
            if idx in selected_indices:
                continue
            m = by_index[idx]
            if any(abs(m.time_sec - s.time_sec) < relaxed_gap for s in selected):
                continue
            # Score recovered frames too, so final_score averages stay meaningful.
            _assign_spacing_scores(m, selected, profile)
            selected.append(m)
            selected_indices.add(m.index)
            if len(selected) >= profile.min_selected:
                break

    log.info("Selected %d frames", len(selected))
    return sorted(selected, key=lambda x: x.time_sec)


def write_outputs(
    candidates: List[Tuple[int, float, np.ndarray]],
    selected_metrics: List[FrameMetrics],
    all_metrics: List[FrameMetrics],
    out_dir: Path,
    profile: ExtractionProfile,
    video_path: Path,
) -> Dict:
    """Persist the selection: JPEG images, manifest.json and summary.json.

    Selected frames are written in the given order to ``out_dir/selected_frames``
    as ``frame_<n>_t<seconds>.jpg``, and each metric's ``filename`` is updated to
    that name. manifest.json holds the full profile and the metrics of every
    candidate. summary.json holds a compact overview, which is also returned.

    Raises:
        RuntimeError: if OpenCV fails to write a frame image.
    """
    frames_dir = out_dir / "selected_frames"
    ensure_dir(frames_dir)

    frame_lookup = {cand_idx: image for cand_idx, _, image in candidates}
    for position, metric in enumerate(selected_metrics):
        image = frame_lookup[metric.index]
        name = f"frame_{position:04d}_t{metric.time_sec:06.2f}.jpg"
        target = frames_dir / name
        if not cv2.imwrite(str(target), image):
            raise RuntimeError(f"Failed to write frame image: {target} — check disk space and path")
        metric.filename = name

    def _avg(values: List[float]) -> float:
        # Mean rounded to 3 decimals; 0.0 when nothing was selected.
        return round(float(np.mean(values)) if selected_metrics else 0.0, 3)

    manifest = {
        "fumoca_version": __version__,
        "video_path": str(video_path),
        "profile": asdict(profile),
        "selected_count": len(selected_metrics),
        "candidate_count": len(all_metrics),
        "selected_frames": [asdict(metric) for metric in selected_metrics],
        "all_candidates": [asdict(metric) for metric in all_metrics],
    }

    summary = {
        "fumoca_version": __version__,
        "video_path": str(video_path),
        "profile": profile.name,
        "candidate_count": len(all_metrics),
        "selected_count": len(selected_metrics),
        "selected_times_sec": [round(metric.time_sec, 3) for metric in selected_metrics],
        "avg_blur_selected": _avg([metric.blur_score for metric in selected_metrics]),
        "avg_quality_selected": _avg([metric.quality_score for metric in selected_metrics]),
        "avg_final_score_selected": _avg([metric.final_score for metric in selected_metrics]),
    }

    for file_name, payload in (("manifest.json", manifest), ("summary.json", summary)):
        with (out_dir / file_name).open("w", encoding="utf-8") as fh:
            json.dump(payload, fh, indent=2)

    return summary


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    The options are:

    - ``--input``: video file (required);
    - ``--output``: directory (required);
    - ``--profile``: one of the PROFILES keys, default "standard";
    - ``--version``;
    - ``--log-level``: DEBUG/INFO/WARNING/ERROR, default INFO.
    """
    level_names = ["DEBUG", "INFO", "WARNING", "ERROR"]
    cli = argparse.ArgumentParser(description="Fumoca frame extraction engine")
    cli.add_argument("--input", required=True, help="Path to input video")
    cli.add_argument("--output", required=True, help="Output directory")
    cli.add_argument("--profile", default="standard", choices=sorted(PROFILES))
    cli.add_argument("--version", action="version", version=f"fumoca {__version__}")
    cli.add_argument("--log-level", default="INFO", choices=level_names, help="Logging verbosity (default: INFO)")
    return cli.parse_args()


def main() -> None:
    """Command-line entry point.

    Parses arguments and configures logging. Then runs extraction -> scoring ->
    selection -> output writing, and prints the summary JSON to stdout.

    Raises:
        FileNotFoundError: if ``--input`` does not exist.
    """
    args = parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    video_path = Path(args.input)
    out_dir = Path(args.output)

    # Validate the input before creating anything on disk, so a mistyped
    # --input does not leave an empty output directory behind.
    if not video_path.exists():
        raise FileNotFoundError(f"Input video not found: {video_path}")
    ensure_dir(out_dir)

    profile = PROFILES[args.profile]
    log.info("Fumoca %s — profile=%s input=%s", __version__, profile.name, video_path.name)

    candidates = extract_candidates(video_path, out_dir / "candidates", profile)
    all_metrics = score_candidates(candidates, profile)
    selected = select_frames(candidates, all_metrics, profile)
    summary = write_outputs(candidates, selected, all_metrics, out_dir, profile, video_path)

    log.info("Done — %d/%d frames selected → %s", summary["selected_count"], summary["candidate_count"], out_dir)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
