from __future__ import annotations

import csv
import io
import json
import sqlite3
from contextlib import closing
from datetime import datetime
from http import HTTPStatus
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Dict, Tuple
from urllib.parse import urlparse

from pricing import PriceInput, build_price_input, calculate_pricing

# Absolute directory of this module; every other path is derived from it so the
# server does not depend on the current working directory.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
# Front-end assets served by PricingHandler for non-API GET requests.
STATIC_DIR = PROJECT_ROOT.joinpath("static")
# Directory (created by init_db()) holding the SQLite file of saved records.
DATA_DIR = PROJECT_ROOT.joinpath("data")
DB_PATH = DATA_DIR.joinpath("pricing_records.sqlite3")


def init_db() -> None:
    """Ensure ``DATA_DIR`` exists and the ``pricing_records`` table is created.

    Idempotent: ``mkdir(exist_ok=True)`` and ``CREATE TABLE IF NOT EXISTS``
    make repeated calls (e.g. on every server start) harmless.
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    # sqlite3.Connection's own context manager only commits/rolls back the
    # transaction; it does NOT close the connection, so closing() is what
    # actually releases the database handle.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS pricing_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    created_at TEXT NOT NULL,
                    product_name TEXT NOT NULL,
                    tax_included_cost REAL NOT NULL,
                    freight REAL NOT NULL,
                    discount_rate REAL NOT NULL,
                    platform_rate REAL NOT NULL,
                    promotion_rate REAL NOT NULL,
                    team_rate REAL NOT NULL,
                    vat_rate REAL NOT NULL,
                    break_even_price_raw REAL NOT NULL,
                    break_even_price_ceil INTEGER NOT NULL,
                    list_price_raw REAL NOT NULL,
                    list_price_ceil INTEGER NOT NULL,
                    discount_sale_price REAL NOT NULL,
                    discount_break_even INTEGER NOT NULL,
                    final_list_price REAL,
                    note TEXT,
                    input_json TEXT NOT NULL,
                    result_json TEXT NOT NULL
                )
                """
            )


def record_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
    """Convert a ``sqlite3.Row`` into a plain dict, preserving column order.

    ``discount_break_even`` is stored as an INTEGER 0/1 and is turned back
    into a real ``bool`` here so it serialises as ``true``/``false``.
    """
    record = {column: row[column] for column in row.keys()}
    record["discount_break_even"] = bool(record["discount_break_even"])
    return record


def save_record(item: PriceInput, result: Dict[str, Any]) -> int:
    """Insert one pricing calculation into ``pricing_records``.

    Args:
        item: The validated input. Its attributes fill the input columns, and
            ``item.__dict__`` is additionally stored verbatim as ``input_json``.
        result: Dict form of the calculation result. It must contain the
            ``break_even_price_raw``/``_ceil``, ``list_price_raw``/``_ceil``,
            ``discount_sale_price`` and ``discount_break_even`` keys, and is
            stored in full as ``result_json``.

    Returns:
        The ``id`` of the newly inserted row.

    Raises:
        KeyError: if ``result`` lacks one of the keys listed above.
        sqlite3.Error: on database failure (the insert is rolled back).
    """
    # Naive local wall-clock time at second precision.
    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    params = (
        created_at,
        item.product_name,
        item.tax_included_cost,
        item.freight,
        item.discount_rate,
        item.platform_rate,
        item.promotion_rate,
        item.team_rate,
        item.vat_rate,
        result["break_even_price_raw"],
        result["break_even_price_ceil"],
        result["list_price_raw"],
        result["list_price_ceil"],
        result["discount_sale_price"],
        # Stored as INTEGER 0/1; record_to_dict() converts it back to bool.
        1 if result["discount_break_even"] else 0,
        item.final_list_price,
        item.note,
        json.dumps(item.__dict__, ensure_ascii=False),
        json.dumps(result, ensure_ascii=False),
    )
    # closing() releases the connection; the inner `with conn` commits on
    # success and rolls back if the INSERT raises.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            cur = conn.execute(
                """
                INSERT INTO pricing_records (
                    created_at, product_name, tax_included_cost, freight, discount_rate,
                    platform_rate, promotion_rate, team_rate, vat_rate,
                    break_even_price_raw, break_even_price_ceil, list_price_raw,
                    list_price_ceil, discount_sale_price, discount_break_even,
                    final_list_price, note, input_json, result_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                params,
            )
        return int(cur.lastrowid)


def list_records() -> Dict[str, Any]:
    """Return the 200 most recent pricing records, newest first.

    Returns:
        ``{"records": [...], "db_path": "<path>"}``. Each record is a dict of
        the summary columns (``input_json``/``result_json`` are not selected)
        with ``discount_break_even`` converted to ``bool``.
    """
    # sqlite3.Connection's context manager does not close the connection;
    # closing() does, so no handle is leaked per request.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            """
            SELECT id, created_at, product_name, tax_included_cost, freight,
                   discount_rate, platform_rate, promotion_rate, team_rate, vat_rate,
                   break_even_price_raw, break_even_price_ceil, list_price_raw,
                   list_price_ceil, discount_sale_price, discount_break_even,
                   final_list_price, note
            FROM pricing_records
            ORDER BY id DESC
            LIMIT 200
            """
        ).fetchall()
    return {"records": [record_to_dict(row) for row in rows], "db_path": str(DB_PATH)}


def export_csv_text() -> str:
    """Render the records returned by ``list_records()`` as CSV text.

    The header row lists the summary columns in table order. ``list_records()``
    caps the result at the 200 most recent rows. Booleans are written as
    ``True``/``False``.

    String cells that start with a formula trigger character (``=``, ``+``,
    ``-``, ``@``, tab, CR) are prefixed with a single quote. ``product_name``
    and ``note`` come straight from API clients, and spreadsheet programs
    would otherwise evaluate such cells as formulas (CSV/formula injection).
    Numeric cells are never altered.
    """
    fieldnames = [
        "id",
        "created_at",
        "product_name",
        "tax_included_cost",
        "freight",
        "discount_rate",
        "platform_rate",
        "promotion_rate",
        "team_rate",
        "vat_rate",
        "break_even_price_raw",
        "break_even_price_ceil",
        "list_price_raw",
        "list_price_ceil",
        "discount_sale_price",
        "discount_break_even",
        "final_list_price",
        "note",
    ]
    formula_triggers = ("=", "+", "-", "@", "\t", "\r")

    def neutralize(value: Any) -> Any:
        # Only text can be interpreted as a formula; leave numbers/None/bool alone.
        if isinstance(value, str) and value.startswith(formula_triggers):
            return "'" + value
        return value

    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames)
    writer.writeheader()
    for record in list_records()["records"]:
        writer.writerow({key: neutralize(value) for key, value in record.items()})
    return output.getvalue()


class PricingHandler(SimpleHTTPRequestHandler):
    """Serve the static front end plus the JSON pricing API.

    GET ``/api/health``, ``/api/records`` and ``/api/records.csv`` are API
    endpoints; every other GET path is served as a static file from
    ``STATIC_DIR``. POST ``/api/calculate`` prices a JSON payload; POST
    ``/api/records`` prices it and stores the result.
    """

    # POST endpoints implemented by do_POST; any other POST path is answered
    # with 404 before the body is read or priced.
    _POST_ROUTES = ("/api/calculate", "/api/records")

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Serve static files from STATIC_DIR instead of the process's cwd.
        super().__init__(*args, directory=str(STATIC_DIR), **kwargs)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Print a log line to stdout, prefixed with the local wall-clock time."""
        print("[%s] %s" % (datetime.now().strftime("%H:%M:%S"), fmt % args))

    def _send_json(self, payload: Dict[str, Any], status: int = 200) -> None:
        """Send *payload* as an indented UTF-8 JSON response with *status*."""
        body = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_text(self, text: str, content_type: str, status: int = 200) -> None:
        """Send *text* with the given Content-Type.

        CSV bodies are encoded as ``utf-8-sig`` (UTF-8 with a BOM), presumably
        so spreadsheet programs detect the encoding; other text is plain UTF-8.
        """
        body = text.encode("utf-8-sig" if content_type.startswith("text/csv") else "utf-8")
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_db_error(self, exc: sqlite3.Error) -> None:
        """Report a database failure as a JSON 500 instead of dropping the connection."""
        self._send_json(
            {"ok": False, "error": f"数据库错误：{exc}"},
            HTTPStatus.INTERNAL_SERVER_ERROR,
        )

    def _read_json(self) -> Tuple[Dict[str, Any], str | None]:
        """Read the request body and parse it as a JSON object.

        Returns:
            ``(payload, None)`` on success, or ``({}, message)`` when the
            Content-Length header, the UTF-8 encoding, or the JSON is invalid,
            or when the JSON is not an object. A missing or empty body is
            treated as ``{}``.
        """
        raw_length = self.headers.get("Content-Length", "0") or "0"
        try:
            length = int(raw_length)
        except ValueError:
            length = -1
        # A negative length would make rfile.read() block until the client
        # closes the socket, so it is rejected together with non-numeric values.
        if length < 0:
            return {}, f"Content-Length 请求头无效：{raw_length}"
        if length == 0:
            return {}, None
        try:
            raw = self.rfile.read(length).decode("utf-8")
        except UnicodeDecodeError as exc:
            return {}, f"请求体不是有效的 UTF-8：{exc}"
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError as exc:
            return {}, f"JSON 格式错误：{exc}"
        if not isinstance(payload, dict):
            return {}, "请求体必须是 JSON object"
        return payload, None

    def do_GET(self) -> None:
        """Route API GET requests; fall back to static-file serving otherwise."""
        path = urlparse(self.path).path
        if path == "/api/health":
            self._send_json({"ok": True, "db_path": str(DB_PATH)})
        elif path == "/api/records":
            try:
                records = list_records()
            except sqlite3.Error as exc:
                self._send_db_error(exc)
                return
            self._send_json(records)
        elif path == "/api/records.csv":
            try:
                csv_text = export_csv_text()
            except sqlite3.Error as exc:
                self._send_db_error(exc)
                return
            self._send_text(csv_text, "text/csv; charset=utf-8")
        else:
            super().do_GET()

    def do_POST(self) -> None:
        """Price the JSON body and, for ``/api/records``, persist the result."""
        path = urlparse(self.path).path
        # Route first: unknown endpoints get 404 without parsing or pricing.
        if path not in self._POST_ROUTES:
            self._send_json({"ok": False, "error": "接口不存在"}, HTTPStatus.NOT_FOUND)
            return
        payload, error = self._read_json()
        if error:
            self._send_json({"ok": False, "error": error}, HTTPStatus.BAD_REQUEST)
            return
        try:
            item = build_price_input(payload)
            result = calculate_pricing(item).to_dict()
        except ValueError as exc:
            self._send_json({"ok": False, "error": str(exc)}, HTTPStatus.BAD_REQUEST)
            return

        if path == "/api/calculate":
            self._send_json({"ok": True, "result": result})
            return
        try:
            record_id = save_record(item, result)
        except sqlite3.Error as exc:
            self._send_db_error(exc)
            return
        self._send_json({"ok": True, "id": record_id, "result": result, "db_path": str(DB_PATH)})


def main() -> None:
    """Parse CLI options, prepare the database, then serve until Ctrl+C.

    Options:
        --host  interface to bind (default ``127.0.0.1``)
        --port  TCP port to listen on (default ``8766``)
    """
    import argparse

    cli = argparse.ArgumentParser(description="良渚自研产品电商定价计算器")
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8766)
    options = cli.parse_args()
    host, port = options.host, options.port

    init_db()
    httpd = ThreadingHTTPServer((host, port), PricingHandler)
    print(f"良渚自研产品电商定价计算器：http://{host}:{port}/")
    print(f"核算记录数据库：{DB_PATH}")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server; exit quietly.
        print("\nserver stopped")
    finally:
        # Always release the listening socket.
        httpd.server_close()


# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
