#!/usr/bin/env python3
# -*- coding: utf-8 -*-
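"""Submit one Suno.cn instrumental generation intended to return two BGM variants.
Reference brief: a rhythm midway between two user-provided MP3 reference tracks.
The task is queried once; finished variants are downloaded, and the request and
results are logged under the project directory.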
"""
import json
import os
import pathlib
import re
import sys
import traceback
import urllib.parse
import urllib.request
import urllib.error

BASE = "https://mcp.suno.cn"
# Read from the environment; an empty string means "unset", which main()
# reports as NO_API_KEY and exits with status 2.
API_KEY = os.environ.get("SUNO_CN_API_KEY") or ""
PROJECT = pathlib.Path("/Users/bot1/Volumes/root_for_ai/AI工作区/通用_产品宣传视频_古钱币杜邦纸钱袋包_20260530_1702")
STAMP = "20260607_020318"

# Audio and logs for this run share one folder name derived from STAMP:
# audio goes under 04_audio_bgm/, request/summary logs under prompts/audio_bgm/.
_RUN_DIR_NAME = f"suno_reference_midtempo_{STAMP}"
OUT_DIR = PROJECT.joinpath("04_audio_bgm", _RUN_DIR_NAME)
LOG_DIR = PROJECT.joinpath("prompts", "audio_bgm", _RUN_DIR_NAME)

# Both folders are created at import time so later writes can assume they exist.
for _run_dir in (OUT_DIR, LOG_DIR):
    _run_dir.mkdir(parents=True, exist_ok=True)
del _run_dir

# Generation payload: written to LOG_DIR/generation_request.json, POSTed to
# /mcp/api/generate by main(), and embedded in the run summary.
# Every string here is sent to the service verbatim.
REQUEST = {
    "title": "Coin Pouch Midtempo Reference BGM",
    # Natural-language brief; the BPM range mirrors REFERENCE_ANALYSIS["target"].
    # NOTE(review): this prompt is long; if the service enforces a prompt length
    # limit in non-custom mode, confirm it is not truncated or rejected.
    "prompt": (
        "Instrumental background music for a short vertical cultural creative product video. "
        "The rhythm should sit between two reference tracks: not too sparse or slow, not too fast or busy; "
        "a polished midtempo bounce around 155-165 BPM feel, with half-time warmth so it remains usable under product footage. "
        "Story arc: ancient coins wake up with tiny playful life, move forward together, a DuPont paper coin pouch arrives and catches them, "
        "then the pouch becomes a clean premium hero product shot. "
        "Mood: bright, warm, lively, refined, emotionally friendly, premium product-commercial, modern oriental craft. "
        "Arrangement: light plucked strings or marimba pulse, small metallic coin-click percussion, soft hand percussion, warm bass, gentle synth pad, delicate bell sparkles, subtle cinematic chord lift, satisfying clean ending. "
        "Keep it loop-friendly and easy to edit for a 15-second product video. No vocals, no singing, no lyrics, no rap, no heavy drums, no dark mystery, no horror, no empty museum ambient, no childish cartoon comedy."
    ),
    "tags": (
        "instrumental, short product ad BGM, midtempo, 160 BPM feel, bright, warm, lively, premium, "
        "modern oriental craft, plucked strings, marimba, metallic coin clicks, soft percussion, bell sparkles, no vocals"
    ),
    # Presumably the model/version selector; confirm against the mcp.suno.cn API docs.
    "mv": "chirp-fenix",
    # NOTE(review): with custom_mode False, confirm whether the service honours
    # "title" and "tags" or only uses "prompt"; this cannot be verified from here.
    "custom_mode": False,
    "instrumental": True,
}

# Characteristics of the two reference MP3s, kept for the record only.
# main() writes this to LOG_DIR/reference_analysis.json and copies it into the
# run summary. It is NOT sent to the API; only REQUEST is POSTed.
# Keys are reference file names, plus one extra "target" key holding a string,
# so not every value is a dict.
# NOTE(review): the tool/method that produced these numbers is not recorded here.
REFERENCE_ANALYSIS = {
    "audio_caf06134b73d.mp3": {
        "duration_sec": 19.27,
        "estimated_bpm_candidates": [147, 146, 145, 144, 143],
        "onset_peaks_per_sec": 3.27,
        "role": "lower/steadier side of reference rhythm",
    },
    "audio_6acbeae60412.mp3": {
        "duration_sec": 24.74,
        "estimated_bpm_candidates": [178, 177, 176, 175, 174],
        "onset_peaks_per_sec": 3.03,
        "role": "faster/brighter side of reference rhythm",
    },
    # Summary of the intent that REQUEST["prompt"] spells out for the service.
    "target": "midpoint energy, approximately 155-165 BPM feel, not slower than the first and not as fast/busy as the second",
}


def redact(text: str) -> str:
    """Return *text* with secrets masked, for safe logging and printing.

    The configured API_KEY (when non-empty) is replaced verbatim first. Then
    any token shaped like ``sk-<letters/digits/_/->`` is masked as well, which
    covers keys that were not the configured one.
    """
    masked = text if not API_KEY else text.replace(API_KEY, "********")
    return re.sub(r"sk-[A-Za-z0-9_\-]+", "sk-********", masked)


def request_json(method: str, path: str, body=None, timeout=60):
    """Send an authenticated request to the Suno.cn MCP API.

    Network and HTTP failures are returned as data instead of raised, so the
    caller can always log a summary.

    Args:
        method: HTTP verb such as "GET" or "POST".
        path: Path (optionally with a query string) appended to BASE.
        body: Optional JSON-serializable payload, sent as UTF-8 JSON when given.
        timeout: Socket timeout in seconds passed to urlopen.

    Returns:
        A dict that always has "ok", "status", "headers", "url" and "method".
        On success "ok" is True and "json" holds the parsed body, or
        ``{"_raw": text}`` when the body is not valid JSON.
        On failure "ok" is False and "body", "error_type" and "error" describe
        the problem; non-HTTP failures also carry a "trace".
    """
    url = BASE + path
    data = None
    headers = {"Authorization": f"Bearer {API_KEY}"}
    if body is not None:
        data = json.dumps(body, ensure_ascii=False).encode("utf-8")
        headers["Content-Type"] = "application/json; charset=utf-8"
    req = urllib.request.Request(url, data=data, headers=headers, method=method)
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            raw = resp.read().decode("utf-8", errors="replace")
            try:
                parsed = json.loads(raw)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                parsed = {"_raw": raw}
            return {
                "ok": True,
                "status": resp.status,
                "headers": dict(resp.headers),
                "json": parsed,
                "url": url,
                "method": method,
            }
    except urllib.error.HTTPError as e:
        # Reading the error body can itself fail (connection reset, timeout).
        # An exception raised inside this handler would escape the function
        # (the generic handler below does not cover it), so contain it here.
        try:
            raw = e.read().decode("utf-8", errors="replace")
        except Exception as read_err:
            raw = f"<unreadable error body: {type(read_err).__name__}: {read_err}>"
        return {
            "ok": False,
            "status": e.code,
            # e.headers can be None for manually constructed HTTPErrors.
            "headers": dict(e.headers or {}),
            "body": raw,
            "url": url,
            "method": method,
            "error_type": "HTTPError",
            "error": str(e),
        }
    except Exception as e:
        # Boundary catch: URLError, timeouts, SSL errors, etc. become data.
        return {
            "ok": False,
            "status": None,
            "headers": {},
            "body": "",
            "url": url,
            "method": method,
            "error_type": type(e).__name__,
            "error": str(e),
            "trace": traceback.format_exc(),
        }


def error_record(err):
    """Project a failed request_json() result onto a fixed set of diagnostic fields.

    Keys absent from *err* map to None, so the record always has the same
    shape and key order regardless of which failure branch produced it.
    """
    fields = ("method", "url", "status", "error_type", "error", "headers", "body", "trace")
    return {name: err.get(name) for name in fields}


# URL schemes allowed for downloading generated audio. urlopen also accepts
# file:// and others, and play_url comes from the remote response.
_DOWNLOAD_SCHEMES = ("http", "https")


def _dump_summary(summary):
    """Write the redacted *summary* to LOG_DIR/submission_summary.json and print it."""
    text = redact(json.dumps(summary, ensure_ascii=False, indent=2))
    (LOG_DIR / "submission_summary.json").write_text(text, encoding="utf-8")
    print(text)


def _safe_filename_part(value, fallback, limit=50):
    """Reduce *value* to [A-Za-z0-9_-] (truncated to *limit* chars), or return *fallback*."""
    cleaned = re.sub(r"[^A-Za-z0-9_\-]+", "_", str(value or "")).strip("_")[:limit]
    return cleaned or fallback


def _extract_serials(data):
    """Return submitted task serial numbers from the submit response as strings.

    Accepts a list/tuple, a single string, or a single integer under any of the
    key spellings the service might use; anything else yields no serials.
    """
    if not isinstance(data, dict):
        return []
    serials = data.get("serial_nos") or data.get("serialNos") or data.get("serial_no") or []
    if isinstance(serials, (str, int)):
        serials = [serials]
    if not isinstance(serials, (list, tuple)):
        return []
    return [str(s) for s in serials if s]


def _extract_tasks(query_json):
    """Return the task dicts from a task-query response, tolerating several shapes."""
    if isinstance(query_json, list):
        tasks = query_json
    elif isinstance(query_json, dict):
        tasks = query_json.get("tasks") or query_json.get("data") or query_json.get("list") or []
    else:
        tasks = []
    if isinstance(tasks, dict):
        tasks = [tasks]
    if not isinstance(tasks, list):
        return []
    return [t for t in tasks if isinstance(t, dict)]


def _process_task(index, task):
    """Summarize one task and download its audio if it finished successfully.

    Returns ``(item, saved)``: *item* is the summary dict for the task and
    *saved* is True when the MP3 was written to OUT_DIR. Download problems are
    recorded in ``item["download_error"]`` instead of being raised.
    """
    serial = str(task.get("serial_no") or task.get("serialNo") or task.get("id") or "")
    status = str(task.get("status") or "")
    play_url = task.get("play_url") or task.get("audio_url") or task.get("url")
    item = {
        "variant_index": index,
        "serial_no": serial,
        "status": status,
        "title": task.get("title"),
        "duration": task.get("duration"),
        "play_url": play_url,
        "fail_reason": task.get("fail_reason"),
    }
    if status.lower() != "success" or not play_url:
        return item, False

    url_text = str(play_url)
    scheme = urllib.parse.urlsplit(url_text).scheme.lower()
    if scheme not in _DOWNLOAD_SCHEMES:
        item["download_error"] = f"UnsupportedURLScheme: {scheme or '<none>'}"
        return item, False

    # Both parts come from the server, so sanitize them before building a path
    # (a serial containing "/" would otherwise escape OUT_DIR).
    safe_title = _safe_filename_part(task.get("title"), "reference_midtempo")
    safe_serial = _safe_filename_part(serial, "", limit=None)
    local_path = OUT_DIR / f"coin_pouch_reference_midtempo_v{index}_{safe_serial}_{safe_title}.mp3"
    try:
        with urllib.request.urlopen(url_text, timeout=90) as resp:
            content = resp.read()
        local_path.write_bytes(content)
    except Exception as e:
        # Best effort: record the failure and keep processing other variants.
        item["download_error"] = f"{type(e).__name__}: {e}"
        return item, False
    item["local_path"] = str(local_path)
    item["bytes"] = len(content)
    return item, True


def main():
    """Submit the generation, query the tasks once, download results and log everything.

    Returns:
        2 if SUNO_CN_API_KEY is not set, 1 if the submission failed or the run
        is not ok, 0 otherwise. A JSON summary is printed and, except in the
        missing-key case, written to LOG_DIR/submission_summary.json.
    """
    if not API_KEY:
        print(json.dumps({"ok": False, "error": "NO_API_KEY", "message": "SUNO_CN_API_KEY is not set"}, ensure_ascii=False))
        return 2

    (LOG_DIR / "reference_analysis.json").write_text(json.dumps(REFERENCE_ANALYSIS, ensure_ascii=False, indent=2), encoding="utf-8")
    (LOG_DIR / "generation_request.json").write_text(json.dumps(REQUEST, ensure_ascii=False, indent=2), encoding="utf-8")

    submit = request_json("POST", "/mcp/api/generate", body=REQUEST, timeout=60)
    summary = {
        "ok": False,
        "stamp": STAMP,
        "request": REQUEST,
        "reference_analysis": REFERENCE_ANALYSIS,
        "submit_status": submit.get("status"),
        "submitted_serials": [],
        "submit_raw": None,
        "query_status": None,
        "query_ok": None,
        "tasks": [],
        "saved": [],
        "out_dir": str(OUT_DIR),
        "log_dir": str(LOG_DIR),
    }

    if not submit.get("ok") or submit.get("status") != 200:
        summary["error"] = error_record(submit)
        _dump_summary(summary)
        return 1

    data = submit.get("json") or {}
    serials = _extract_serials(data)
    summary["submitted_serials"] = serials
    summary["submit_raw"] = data

    if serials:
        path_ids = ",".join(urllib.parse.quote(s, safe="") for s in serials)
        # Single query; the server is asked to wait up to 45s for completion.
        query = request_json("GET", f"/mcp/api/task/{path_ids}?wait=45", timeout=75)
        summary["query_status"] = query.get("status")
        summary["query_ok"] = query.get("ok")
        if not query.get("ok") or query.get("status") != 200:
            summary["query_error"] = error_record(query)
        else:
            for index, task in enumerate(_extract_tasks(query.get("json")), start=1):
                item, saved = _process_task(index, task)
                summary["tasks"].append(item)
                if saved:
                    summary["saved"].append(item)

    # ok = serials were submitted and no task is in a failed/unknown state.
    # Note: all() over an empty task list is True, so a failed or empty query
    # still counts as ok; download errors are recorded per item but do not
    # affect ok either.
    acceptable = {"success", "queued", "processing"}
    summary["ok"] = bool(serials) and all(t.get("status", "").lower() in acceptable for t in summary["tasks"])
    _dump_summary(summary)
    return 0 if summary["ok"] else 1


if __name__ == "__main__":
    # Propagate main()'s status (0 ok, 1 failure, 2 missing API key) to the shell.
    exit_code = main()
    sys.exit(exit_code)
