from __future__ import annotations

from dataclasses import asdict, dataclass
from math import ceil, isfinite
from typing import Any, Dict, List, Optional


# Pricing-type keys. These strings are what payloads must carry in
# "pricing_type"; they are also the keys of PRICING_TYPE_CONFIG below.
PRICING_TYPE_SELF_DEVELOPED = "liangzhu_self_developed"  # Liangzhu self-developed products
PRICING_TYPE_FACTORY_PURCHASE = "liangzhu_factory_purchase"  # products priced from a factory purchase cost

# Per-pricing-type defaults used when a payload omits a field.
#   label              – display name shown to users (also listed in error messages)
#   freight            – flat shipping cost subtracted once per sale
#   *_rate             – fractions of the sale price (0.061 == 6.1%)
#   vat_rate           – VAT rate; VAT is extracted from tax-inclusive amounts
#                        with vat_rate / (1 + vat_rate)
#   target_profit_rate – share of the sale price that must remain as profit
#   discount_rate      – discounted sale price = list price * discount_rate
#   formula_note       – human-readable formula returned with each result.
#                        NOTE: the notes hard-code 13/113, i.e. they assume
#                        vat_rate == 0.13 even if a caller overrides it.
PRICING_TYPE_CONFIG: Dict[str, Dict[str, Any]] = {
    PRICING_TYPE_SELF_DEVELOPED: {
        "label": "良渚自研产品",
        "freight": 8.0,
        "ip_rate": 0.0,
        "platform_rate": 0.061,
        "promotion_rate": 0.10,
        "team_rate": 0.28,
        "vat_rate": 0.13,
        "target_profit_rate": 0.0,
        "discount_rate": 0.85,
        "formula_note": "最低成交价 = (运费 + 含税成本×100/113) / (1 - 平台 - 推广 - 团队 - 13/113)",
    },
    PRICING_TYPE_FACTORY_PURCHASE: {
        "label": "厂家采购成本产品",
        "freight": 5.5,
        "ip_rate": 0.03,
        "platform_rate": 0.065,
        "promotion_rate": 0.10,
        "team_rate": 0.10,
        "vat_rate": 0.13,
        "target_profit_rate": 0.10,
        "discount_rate": 0.85,
        "formula_note": "最低成交价 = (运费 + 厂家采购含税成本×100/113) / (1 - IP提成 - 平台 - 推广 - 团队 - 目标利润 - 13/113)",
    },
}


@dataclass(frozen=True)
class PriceInput:
    """Pricing inputs for one product.

    All *_rate fields are fractions of the sale price (0.061 means 6.1%).
    The field defaults mirror the self-developed entry of
    PRICING_TYPE_CONFIG; build_price_input fills them from the config of
    the selected pricing type instead.
    """

    product_name: str
    tax_included_cost: float  # cost including VAT; validated to be > 0
    pricing_type: str = PRICING_TYPE_SELF_DEVELOPED  # key of PRICING_TYPE_CONFIG
    freight: float = 8.0  # flat shipping cost subtracted once per sale; must be >= 0
    ip_rate: float = 0.0  # IP royalty share of the sale price
    platform_rate: float = 0.061  # platform commission share
    promotion_rate: float = 0.10  # promotion/advertising share
    team_rate: float = 0.28  # team commission share
    vat_rate: float = 0.13  # VAT rate; VAT portion of a gross amount = rate / (1 + rate)
    target_profit_rate: float = 0.0  # share of the sale price to keep as profit
    discount_rate: float = 0.85  # sale price = list price * discount_rate, in (0, 1]
    note: str = ""  # free-text remark
    # Final list price chosen by the user, if any.
    # NOTE(review): calculate_pricing in this file does not read it — confirm where it is used.
    final_list_price: Optional[float] = None


@dataclass(frozen=True)
class PriceResult:
    """Immutable outcome of a pricing calculation for one product.

    Rates are fractions; amounts share the currency unit of the input cost.
    """

    product_name: str
    pricing_type: str
    pricing_type_label: str
    total_cost_rate: float  # ip + platform + promotion + team rates
    target_profit_rate: float
    vat_factor: float  # vat_rate / (1 + vat_rate)
    break_even_price_raw: float  # unrounded minimum transaction price
    break_even_price_ceil: int
    list_price_raw: float  # minimum price divided by the discount rate
    list_price_ceil: int
    discount_sale_price: float  # rounded list price times the discount rate
    ip_fee: float
    platform_fee: float
    promotion_fee: float
    team_fee: float
    output_vat: float
    input_vat: float
    net_vat: float
    target_profit_amount: float
    profit_at_break_even_raw: float
    profit_at_discount_sale: float
    discount_break_even: bool  # True when the discounted price still meets the target profit
    suggested_price_tiers: List[int]
    formula_note: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict in field order.

        Top-level float values are rounded to 6 decimal places; ints, bools,
        strings and the tier list are passed through unchanged.
        """
        return {
            field_name: round(field_value, 6) if isinstance(field_value, float) else field_value
            for field_name, field_value in asdict(self).items()
        }


def normalize_rate(value: Any, default: float) -> float:
    """Parse a rate given either as a fraction (0.13) or a percentage (13).

    - ``None``, ``""`` and whitespace-only strings return *default* unchanged.
    - A string ending in ``%`` is always a percentage: ``"0.5%"`` -> 0.005.
    - Otherwise a value greater than 1 is assumed to be a percentage and is
      divided by 100 (so exactly 1 stays 1.0, i.e. 100%).

    Raises:
        ValueError: if the value is not a finite number. NaN in particular
            must be rejected here because every later range comparison
            (``< 0``, ``> 1``) is False for NaN and would let it through.
    """
    if value is None:
        return default

    explicit_percent = False
    raw = value
    if isinstance(value, str):
        raw = value.strip()
        if not raw:
            return default
        if raw.endswith("%"):
            explicit_percent = True
            raw = raw[:-1].strip()

    try:
        rate = float(raw)
    except (TypeError, ValueError):
        raise ValueError(f"比例必须是数字：{value!r}") from None
    if not isfinite(rate):
        raise ValueError(f"比例必须是有限数字：{value!r}")

    if explicit_percent or rate > 1:
        rate = rate / 100
    return rate


def normalize_pricing_type(value: Any) -> str:
    """Return the pricing-type key for *value*.

    A falsy value (None, "", 0, ...) selects the self-developed type, and
    surrounding whitespace is ignored.

    Raises:
        ValueError: if the resulting key is not in PRICING_TYPE_CONFIG; the
            message lists the display labels of the supported types.
    """
    candidate = (str(value) if value else PRICING_TYPE_SELF_DEVELOPED).strip()
    if candidate in PRICING_TYPE_CONFIG:
        return candidate
    choices = "、".join(entry["label"] for entry in PRICING_TYPE_CONFIG.values())
    raise ValueError(f"核算类型不支持，可选：{choices}")


def get_pricing_type_label(pricing_type: str) -> str:
    """Return the display label for *pricing_type*.

    Unknown keys fall back to the self-developed type's label instead of raising.
    """
    entry = PRICING_TYPE_CONFIG.get(pricing_type)
    if entry is None:
        entry = PRICING_TYPE_CONFIG[PRICING_TYPE_SELF_DEVELOPED]
    return str(entry["label"])


def _to_finite_float(value: Any, error_message: str) -> float:
    """Convert *value* to a finite float, or raise ``ValueError(error_message)``.

    Rejects anything ``float()`` cannot parse (including ``None``) as well as
    NaN and +/-infinity. NaN must be caught here because every ``<= 0`` /
    ``< 0`` range check is False for NaN and would let it through.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        raise ValueError(error_message) from None
    if not isfinite(number):
        raise ValueError(error_message)
    return number


def build_price_input(payload: Dict[str, Any]) -> PriceInput:
    """Validate a raw request payload and convert it into a PriceInput.

    product_name and tax_included_cost are required. Other numeric fields
    that are missing, None or "" fall back to the defaults of the selected
    pricing type. The exception is final_list_price, where blank means "not
    chosen" and becomes None. Rates may be given as fractions or as
    percentages (see normalize_rate).

    Raises:
        ValueError: with a user-facing message when a field is missing,
            non-numeric, non-finite or out of range.
    """
    pricing_type = normalize_pricing_type(payload.get("pricing_type"))
    defaults = PRICING_TYPE_CONFIG[pricing_type]

    name = str(payload.get("product_name") or "").strip()
    if not name:
        raise ValueError("产品名称不能为空")

    cost = _to_finite_float(payload.get("tax_included_cost"), "含税成本价必须是数字")
    if cost <= 0:
        raise ValueError("含税成本价必须大于 0")

    # A blank freight field falls back to the pricing type's default, like the
    # rate fields below, so an empty form field does not silently become 0
    # and under-price the product. An explicit 0 is still honoured.
    raw_freight = payload.get("freight")
    if raw_freight in (None, ""):
        freight = float(defaults["freight"])
    else:
        freight = _to_finite_float(raw_freight, "运费必须是数字")
    if freight < 0:
        raise ValueError("运费不能为负数")

    discount_rate = normalize_rate(payload.get("discount_rate"), defaults["discount_rate"])
    # Chained comparison so that NaN is rejected as well.
    if not 0 < discount_rate <= 1:
        raise ValueError("折扣率必须在 0 到 1 之间，例如 0.85")

    final_list_price = payload.get("final_list_price")
    if final_list_price in (None, ""):
        final_value = None
    else:
        final_value = _to_finite_float(final_list_price, "最终采用一口价必须是数字")
        if final_value <= 0:
            raise ValueError("最终采用一口价必须大于 0")

    rate_keys = ("ip_rate", "platform_rate", "promotion_rate", "team_rate", "vat_rate", "target_profit_rate")
    rates = {key: normalize_rate(payload.get(key), defaults[key]) for key in rate_keys}
    for label, rate in rates.items():
        if not isfinite(rate):
            raise ValueError(f"{label} 必须是有限数字")
        if rate < 0:
            raise ValueError(f"{label} 不能为负数")

    return PriceInput(
        product_name=name,
        tax_included_cost=cost,
        pricing_type=pricing_type,
        freight=freight,
        ip_rate=rates["ip_rate"],
        platform_rate=rates["platform_rate"],
        promotion_rate=rates["promotion_rate"],
        team_rate=rates["team_rate"],
        vat_rate=rates["vat_rate"],
        target_profit_rate=rates["target_profit_rate"],
        discount_rate=discount_rate,
        note=str(payload.get("note") or "").strip(),
        final_list_price=final_value,
    )


def calculate_profit(sale_price: float, item: PriceInput) -> float:
    """Return the profit left when *item* sells at the tax-inclusive *sale_price*.

    profit = price - percentage fees - freight - tax-inclusive cost - net VAT,
    where net VAT is the VAT share (vat_rate / (1 + vat_rate)) of the spread
    between the sale price and the tax-inclusive cost.
    """
    fee_share = item.ip_rate + item.platform_rate + item.promotion_rate + item.team_rate
    tax_share = item.vat_rate / (1 + item.vat_rate)
    fees = sale_price * fee_share
    vat_payable = (sale_price - item.tax_included_cost) * tax_share
    return sale_price - fees - item.freight - item.tax_included_cost - vat_payable


def _ceil_price(value: float) -> int:
    """Round a computed price up to a whole unit, ignoring float round-off.

    Subtracting a 1e-9 tolerance keeps pure floating-point noise such as
    85.00000000000001 from being bumped up to 86.
    """
    return ceil(value - 1e-9)


def calculate_pricing(item: PriceInput) -> PriceResult:
    """Compute the minimum transaction price, list price and fee breakdown.

    The minimum price P ("break-even" including the target profit) solves

        P * (1 - fee rates - target_profit_rate - vat_factor)
            = freight + cost * (1 - vat_factor)

    with vat_factor = vat_rate / (1 + vat_rate). At P, the price minus fees,
    freight, tax-inclusive cost and net VAT equals P * target_profit_rate.
    The list price is P / discount_rate rounded up to a whole unit, so that
    selling at the discount still reaches P.

    Raises:
        ValueError: for non-finite numbers, a non-positive cost, negative
            freight or rates, a discount outside (0, 1], an unsupported
            pricing type, or rates that leave no margin (denominator <= 0).
    """
    rate_fields = ("ip_rate", "platform_rate", "promotion_rate", "team_rate", "vat_rate", "target_profit_rate")
    # NaN/inf slip past every range comparison below (and make ceil() raise an
    # unrelated error), so reject them first.
    for field_name in ("tax_included_cost", "freight", "discount_rate", *rate_fields):
        if not isfinite(getattr(item, field_name)):
            raise ValueError(f"{field_name} 必须是有限数字")
    if item.tax_included_cost <= 0:
        raise ValueError("含税成本价必须大于 0")
    if item.freight < 0:
        raise ValueError("运费不能为负数")
    if item.discount_rate <= 0 or item.discount_rate > 1:
        raise ValueError("折扣率必须在 0 到 1 之间，例如 0.85")
    # Same rule build_price_input enforces, for PriceInputs constructed directly.
    for field_name in rate_fields:
        if getattr(item, field_name) < 0:
            raise ValueError(f"{field_name} 不能为负数")
    # Use the normalized key: normalize_pricing_type maps ""/None to the
    # default and strips whitespace instead of raising, so indexing the config
    # with the raw value could raise KeyError.
    pricing_type = normalize_pricing_type(item.pricing_type)

    total_rate = item.ip_rate + item.platform_rate + item.promotion_rate + item.team_rate
    vat_factor = item.vat_rate / (1 + item.vat_rate)
    denominator = 1 - total_rate - item.target_profit_rate - vat_factor
    if denominator <= 0:
        raise ValueError("费率和目标利润合计过高，无法计算最低成交价")

    break_even_raw = (item.freight + item.tax_included_cost * (1 - vat_factor)) / denominator
    break_even_ceil = _ceil_price(break_even_raw)
    list_price_raw = break_even_raw / item.discount_rate
    list_price_ceil = _ceil_price(list_price_raw)
    discount_sale_price = list_price_ceil * item.discount_rate

    # Fee, VAT and profit breakdown evaluated at the unrounded minimum price.
    ip_fee = break_even_raw * item.ip_rate
    platform_fee = break_even_raw * item.platform_rate
    promotion_fee = break_even_raw * item.promotion_rate
    team_fee = break_even_raw * item.team_rate
    output_vat = break_even_raw * vat_factor
    input_vat = item.tax_included_cost * vat_factor
    net_vat = (break_even_raw - item.tax_included_cost) * vat_factor
    target_profit_amount = break_even_raw * item.target_profit_rate
    profit_at_break_even = calculate_profit(break_even_raw, item)
    profit_at_discount = calculate_profit(discount_sale_price, item)
    # 1e-9 tolerance absorbs float round-off when the discounted price sits
    # exactly on the target.
    discount_meets_target = profit_at_discount + 1e-9 >= discount_sale_price * item.target_profit_rate

    config = PRICING_TYPE_CONFIG[pricing_type]
    return PriceResult(
        product_name=item.product_name,
        pricing_type=pricing_type,
        pricing_type_label=str(config["label"]),
        total_cost_rate=total_rate,
        target_profit_rate=item.target_profit_rate,
        vat_factor=vat_factor,
        break_even_price_raw=break_even_raw,
        break_even_price_ceil=break_even_ceil,
        list_price_raw=list_price_raw,
        list_price_ceil=list_price_ceil,
        discount_sale_price=discount_sale_price,
        ip_fee=ip_fee,
        platform_fee=platform_fee,
        promotion_fee=promotion_fee,
        team_fee=team_fee,
        output_vat=output_vat,
        input_vat=input_vat,
        net_vat=net_vat,
        target_profit_amount=target_profit_amount,
        profit_at_break_even_raw=profit_at_break_even,
        profit_at_discount_sale=profit_at_discount,
        discount_break_even=discount_meets_target,
        suggested_price_tiers=suggest_price_tiers(list_price_ceil),
        formula_note=str(config["formula_note"]),
    )


def suggest_price_tiers(base_price: float, count: int = 6) -> List[int]:
    """Return practical e-commerce list-price tiers at or above *base_price*.

    The first tier is *base_price* rounded up to a whole unit, kept for
    traceability. The following tiers are the next prices, in ascending
    order, whose last two digits are one of 09, 19, 29, 39, 49, 59, 69, 79
    or 99. They are strictly greater than the first tier, which avoids
    overly dense one-unit increments.

    Args:
        base_price: Computed list price; non-integer values are rounded up.
        count: Number of tiers to return (default 6). Values below 1 still
            return the base tier alone.

    Returns:
        A strictly increasing list of ``max(count, 1)`` integer prices.
    """
    # NOTE(review): there is no x89 ending — presumably deliberate; confirm.
    endings = (9, 19, 29, 39, 49, 59, 69, 79, 99)
    base = ceil(base_price)
    tiers = [base]
    hundred = (base // 100) * 100
    # Walk upward one hundred at a time; endings are ascending, so prices are
    # appended in increasing order.
    while len(tiers) < count:
        for ending in endings:
            price = hundred + ending
            if price > base:
                tiers.append(price)
                if len(tiers) >= count:
                    break
        hundred += 100
    return tiers
