#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import math
import random
import struct
import zlib
from dataclasses import dataclass
from pathlib import Path

# Total number of seats apportioned among the parties in every simulated draw;
# also the target that the calibrated per-party component totals must sum to.
TARGET_TOTAL_SEATS = 5014
# Largest seat threshold for which a ">= N seats" row is generated.
MAX_THRESHOLD_SEATS = 5000
# Spacing between consecutive seat thresholds (threshold 1 is always included).
GLOBAL_THRESHOLD_STEP = 100
# Decimal places used when rounding and formatting probabilities.
PROBABILITY_DECIMALS = 3
# Largest allowed |TARGET_TOTAL_SEATS - calibrated component total| (in seats)
# before write_csv raises.
COMPONENT_TOTAL_TOLERANCE = 0.05
# Written as the market's end_date metadata and as the end_date of every CSV row.
END_DATE = "2026-05-07T21:00:00Z"
# Written as the market's resolution_date metadata.
RESOLUTION_DATE = "2026-05-15T23:59:59Z"
# Written as the market's code metadata.
MARKET_CODE = "ukloc26"
# Seed for the random.Random instance that drives the simulation.
RNG_SEED = 20260507
# Number of simulated seat allocations.
SIM_DRAWS = 50000
# Directory two levels above the directory containing this file.
# NOTE(review): presumably the repository root -- confirm the script's location.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Default CSV output path and chart/summary directory (overridable on the CLI).
DEFAULT_OUTPUT_PATH = REPO_ROOT / "data/local-elections-2026/uk_local_elections_2026_multicount_example.csv"
DEFAULT_ASSETS_DIR = REPO_ROOT / "data/local-elections-2026"


@dataclass(frozen=True)
class PartySpec:
    """Immutable per-party inputs for the seat simulation and the generated outputs."""

    # Identifier used as dict key for draws/probabilities and as the CSV projection_group.
    key: str
    # Human-readable name used in row labels and as key of the expected-seat dicts.
    label: str
    # "#rrggbb" colour used for the charts and passed through in svelte_params.
    color: str
    # Key into FAMILY_PRIORS selecting the family-level swing distribution.
    family: str
    # The swing includes 1.20 * (current_poll_share - baseline_poll_share).
    current_poll_share: float
    baseline_poll_share: float
    # Starting seat share that is multiplied by exp(swing) in each draw.
    baseline_seat_share: float
    # Enters the swing as 0.35 * log1p(defended_seats / TARGET_TOTAL_SEATS).
    defended_seats: int
    # Enters the swing as 0.55 * (incumbent_conversion_rate - 0.5).
    incumbent_conversion_rate: float
    # When True the swing receives a fixed -0.06 adjustment.
    government_party: bool = False


# Positional fields, in order: key, label, color, family, current_poll_share,
# baseline_poll_share, baseline_seat_share, defended_seats,
# incumbent_conversion_rate, [government_party].
# The list order is also the order of parties in the charts and the CSV rows.
# NOTE(review): only Conservative is flagged government_party=True -- confirm this
# is the intended party to receive the government penalty for a May 2026 election.
PARTIES: list[PartySpec] = [
    PartySpec("labour", "Labour", "#dc2626", "major", 0.34, 0.31, 0.37, 1050, 0.54),
    PartySpec("conservative", "Conservative", "#2563eb", "major", 0.22, 0.34, 0.22, 1680, 0.39, True),
    PartySpec("lib_dem", "Liberal Democrat", "#f59e0b", "progressive", 0.13, 0.14, 0.16, 620, 0.55),
    PartySpec("green", "Green", "#16a34a", "challenger", 0.10, 0.07, 0.09, 120, 0.60),
    PartySpec("reform", "Reform UK", "#06b6d4", "challenger", 0.15, 0.03, 0.08, 10, 0.45),
    PartySpec("other", "Other", "#6b7280", "challenger", 0.06, 0.11, 0.08, 1371, 0.51),
]
# family -> (mean, standard deviation) of the Gaussian family-level swing term
# drawn independently for each party in each simulated draw.
FAMILY_PRIORS = {"major": (0.00, 0.09), "progressive": (0.01, 0.08), "challenger": (0.00, 0.10)}


def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Convert an ``rrggbb`` hex colour (leading ``#`` optional) to an RGB triple.

    Args:
        hex_color: Six hexadecimal digits, optionally prefixed with ``#``.

    Returns:
        ``(red, green, blue)`` with each channel parsed from two hex digits.

    Raises:
        ValueError: If any two-character slice is not valid hexadecimal.
    """
    digits = hex_color.lstrip("#")
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return red, green, blue


def clamp_probability(value: float) -> float:
    """Clamp *value* into the closed interval [0.001, 0.999].

    Args:
        value: Candidate probability.

    Returns:
        ``value`` if it already lies strictly inside the interval, otherwise
        the nearest bound.
    """
    # Lower bound first, then upper bound.
    floored = value if value > 0.001 else 0.001
    return floored if floored < 0.999 else 0.999


def seat_thresholds(max_threshold: int, step: int) -> list[int]:
    """Build the ascending list of seat thresholds.

    The result always contains ``1`` and ``max_threshold``, plus every value
    ``max(step, 2), max(step, 2) + step, ...`` that does not exceed
    ``max_threshold``.

    Args:
        max_threshold: Largest threshold to include.
        step: Spacing between regular thresholds; must be positive.

    Returns:
        Sorted list of unique thresholds.

    Raises:
        ValueError: If ``step`` is not positive (the stepping loop could
            otherwise never terminate).
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    values = {1, max_threshold}
    current = max(step, 2)  # never start below 2: threshold 1 is already present
    while current <= max_threshold:
        values.add(current)
        current += step
    return sorted(values)


def gamma_sample(alpha: float, rng: random.Random) -> float:
    """Draw one Gamma(alpha, 1.0) variate from *rng*.

    Args:
        alpha: Shape parameter passed to ``rng.gammavariate``.
        rng: Random generator supplying the draw.

    Returns:
        The sampled value.
    """
    unit_scale = 1.0
    return rng.gammavariate(alpha, unit_scale)


def dirichlet_sample(base_shares: list[float], concentration: float, rng: random.Random) -> list[float]:
    """Draw a share vector by normalising independent gamma variates.

    Each component uses shape ``concentration * share``, floored at 0.05.

    Args:
        base_shares: Centre of the distribution, one share per component.
        concentration: Multiplier applied to each share to obtain its shape.
        rng: Random generator supplying the gamma draws.

    Returns:
        The gamma draws divided by their sum, in the order of ``base_shares``.
    """
    gammas = []
    for share in base_shares:
        shape = max(0.05, concentration * share)
        gammas.append(gamma_sample(shape, rng))
    norm = sum(gammas)
    return [g / norm for g in gammas]


def apportion_to_integer_seats(total_seats: int, shares: list[float]) -> list[int]:
    """Turn fractional shares into integer seat counts by largest remainder.

    Every party first receives the floor of ``share * total_seats``; the
    seats still unassigned go one each to the parties with the largest
    fractional remainders (earlier index wins ties).

    Args:
        total_seats: Number of seats to distribute.
        shares: Fractional share per party.

    Returns:
        Integer seat count per party, in the order of ``shares``.
    """
    quotas = [share * total_seats for share in shares]
    allocation = [int(math.floor(quota)) for quota in quotas]
    leftover = total_seats - sum(allocation)
    # Stable descending sort on the fractional part keeps the lower index first on ties.
    by_fraction = sorted(
        range(len(quotas)),
        key=lambda pos: quotas[pos] - allocation[pos],
        reverse=True,
    )
    for pos in by_fraction[:leftover]:
        allocation[pos] += 1
    return allocation


def normalize_positive(values: list[float]) -> list[float]:
    """Scale *values* so that they sum to one.

    Args:
        values: Weights to normalise.

    Returns:
        ``values`` divided by their sum, or a uniform vector of the same
        length when the sum is zero or negative.
    """
    total = sum(values)
    if total <= 0:
        # No usable mass: fall back to equal shares.
        uniform = 1.0 / len(values)
        return [uniform for _ in values]
    return [item / total for item in values]


def simulate_seat_draws(draws: int, rng_seed: int) -> dict[str, list[int]]:
    """Simulate integer seat allocations for every party in PARTIES.

    Args:
        draws: Number of simulated allocations to produce.
        rng_seed: Seed for the private ``random.Random`` used by the simulation.

    Returns:
        Mapping from party key to a list with one seat count per draw. Within
        a draw the counts come from apportioning TARGET_TOTAL_SEATS.
    """
    rng = random.Random(rng_seed)
    results = {p.key: [] for p in PARTIES}
    for _ in range(draws):
        # Shocks drawn once per draw and shared by all parties.
        national_cycle = rng.gauss(0.0, 0.07)
        turnout_shock = rng.gauss(0.0, 0.05)
        district_mix_shock = rng.gauss(0.0, 0.05)
        unscaled: list[float] = []
        for p in PARTIES:
            fam_mu, fam_sd = FAMILY_PRIORS[p.family]
            # Log-scale swing: fixed party terms, the shared shocks, and two
            # party-specific Gaussian noise terms (family-level and idiosyncratic).
            swing = (
                1.20 * (p.current_poll_share - p.baseline_poll_share)
                + 0.55 * (p.incumbent_conversion_rate - 0.5)
                + 0.35 * math.log1p(p.defended_seats / TARGET_TOTAL_SEATS)
                + (-0.06 if p.government_party else 0.0)
                + 0.80 * national_cycle
                + 0.45 * turnout_shock
                + 0.35 * district_mix_shock
                + rng.gauss(fam_mu, fam_sd)
                + rng.gauss(0.0, 0.06)
            )
            # Multiplicative swing on the baseline seat share, kept strictly positive.
            unscaled.append(max(1e-7, p.baseline_seat_share * math.exp(swing)))
        base_shares = normalize_positive(unscaled)
        # Dirichlet noise around the normalised shares; the concentration itself
        # is random (220 plus Gaussian noise) and floored at 120.
        noisy = dirichlet_sample(base_shares, max(120.0, 220.0 + rng.gauss(0.0, 18.0)), rng)
        seats = apportion_to_integer_seats(TARGET_TOTAL_SEATS, noisy)
        for p, s in zip(PARTIES, seats):
            results[p.key].append(s)
    return results


def survival_from_samples(samples: list[int], thresholds: list[int], decimals: int | None = None) -> list[float]:
    """Estimate P(seats >= t) for each threshold from simulated samples.

    Each empirical frequency is clamped by ``clamp_probability`` and capped
    by the (unrounded) value of the previous threshold, so the curve never
    increases.

    Args:
        samples: Simulated seat counts.
        thresholds: Thresholds to evaluate, in the order given.
        decimals: If not ``None``, round each reported value to this many
            decimal places.

    Returns:
        One probability per threshold.
    """
    sample_count = len(samples)
    curve = []
    ceiling = 0.999
    for threshold in thresholds:
        hits = sum(1 for seats in samples if seats >= threshold)
        ceiling = min(ceiling, clamp_probability(hits / sample_count))
        if decimals is None:
            curve.append(ceiling)
        else:
            curve.append(round(ceiling, decimals))
    return curve


def threshold_widths(thresholds: list[int]) -> list[int]:
    """Return the gap between each threshold and the one before it.

    The first threshold is measured from zero. A gap that is not positive
    is replaced by ``1``.

    Args:
        thresholds: Seat thresholds, in the order they should be processed.

    Returns:
        One width per threshold.
    """
    widths: list[int] = []
    previous = 0
    for raw_threshold in thresholds:
        current = int(raw_threshold)
        gap = current - previous
        widths.append(gap if gap > 0 else 1)
        previous = current
    return widths


def enforce_monotone_nonincreasing(values: list[float]) -> list[float]:
    """Clamp each value and cap it by the running minimum.

    Args:
        values: Probabilities in threshold order.

    Returns:
        A new list in which every entry has been passed through
        ``clamp_probability`` and is no larger than any earlier entry.
    """
    capped: list[float] = []
    ceiling = 0.999
    for raw in values:
        ceiling = min(ceiling, clamp_probability(raw))
        capped.append(ceiling)
    return capped


def logit(probability: float) -> float:
    """Return the log-odds of *probability* after clamping it.

    Args:
        probability: Value passed through ``clamp_probability`` first, which
            keeps the odds finite.

    Returns:
        ``log(p / (1 - p))`` for the clamped probability ``p``.
    """
    clamped = clamp_probability(probability)
    odds = clamped / (1.0 - clamped)
    return math.log(odds)


def logistic(value: float) -> float:
    """Evaluate the logistic function ``1 / (1 + exp(-value))``.

    The exponential is always taken of a non-positive number, so large
    magnitudes of ``value`` do not overflow.

    Args:
        value: Log-odds input.

    Returns:
        The corresponding probability.
    """
    if value < 0:
        growth = math.exp(value)
        return growth / (1.0 + growth)
    decay = math.exp(-value)
    return 1.0 / (1.0 + decay)


def round_probabilities(values: list[float], decimals: int) -> list[float]:
    """Round each probability, then re-apply clamping and monotonicity.

    Args:
        values: Probabilities in threshold order.
        decimals: Number of decimal places to round to.

    Returns:
        The rounded values after ``enforce_monotone_nonincreasing``.
    """
    rounded = []
    for probability in values:
        rounded.append(round(probability, decimals))
    return enforce_monotone_nonincreasing(rounded)


def component_expected_seats_from_probs(probabilities: list[float], widths: list[int]) -> float:
    """Return the width-weighted sum of the probabilities.

    Args:
        probabilities: One probability per threshold.
        widths: One width per threshold; pairs are formed positionally and
            any surplus entries of the longer list are ignored.

    Returns:
        ``sum(width * probability)`` over the paired entries.
    """
    contributions = [width * probability for width, probability in zip(widths, probabilities)]
    return sum(contributions)


def component_expected_total(
    probs_by_party: dict[str, list[float]],
    widths: list[int],
) -> float:
    """Sum the component expected seats over all parties.

    Args:
        probs_by_party: Probabilities per threshold, keyed by party.
        widths: Threshold widths shared by every party.

    Returns:
        The sum of ``component_expected_seats_from_probs`` across parties.
    """
    running_total = 0.0
    for party_probabilities in probs_by_party.values():
        party_seats = component_expected_seats_from_probs(party_probabilities, widths)
        running_total += party_seats
    return running_total


def shift_probs_with_lambda(
    base_logits_by_party: dict[str, list[float]],
    lambda_shift: float,
    decimals: int | None = None,
) -> dict[str, list[float]]:
    """Shift every logit by a common offset and map back to probabilities.

    Args:
        base_logits_by_party: Logits per threshold, keyed by party.
        lambda_shift: Offset added to every logit.
        decimals: If not ``None``, the probabilities are additionally rounded
            via ``round_probabilities``.

    Returns:
        Monotone non-increasing probabilities per party, same keys as the input.
    """
    result: dict[str, list[float]] = {}
    for key, party_logits in base_logits_by_party.items():
        moved = [logistic(entry + lambda_shift) for entry in party_logits]
        monotone = enforce_monotone_nonincreasing(moved)
        result[key] = monotone if decimals is None else round_probabilities(monotone, decimals)
    return result


def calibrate_probs_to_component_total(
    raw_probs_by_party: dict[str, list[float]],
    widths: list[int],
    target_total: float,
    decimals: int,
) -> tuple[dict[str, list[float]], float]:
    """Find a common logit shift that brings the component total to the target.

    All probabilities are converted to logits, shifted by one shared offset,
    converted back and rounded; the offset is located by bisection on
    [-24, 24] using the total computed from the rounded probabilities.

    Args:
        raw_probs_by_party: Uncalibrated probabilities per threshold, by party.
        widths: Threshold widths used to weight the probabilities.
        target_total: Desired sum of component expected seats over all parties.
        decimals: Decimal places the calibrated probabilities are rounded to.

    Returns:
        ``(probabilities, total)``: the calibrated, rounded probabilities and
        the component total they produce.

    Raises:
        ValueError: If the target lies outside the totals reachable at the
            ends of the search interval.
        RuntimeError: If no candidate offset produced a usable result.
    """
    party_logits: dict[str, list[float]] = {}
    for key, probabilities in raw_probs_by_party.items():
        party_logits[key] = [logit(probability) for probability in probabilities]

    def evaluate(offset: float) -> tuple[float, dict[str, list[float]]]:
        candidate = shift_probs_with_lambda(party_logits, offset, decimals=decimals)
        return component_expected_total(candidate, widths), candidate

    lower, upper = -24.0, 24.0
    lower_total = evaluate(lower)[0]
    upper_total = evaluate(upper)[0]
    if target_total < lower_total or target_total > upper_total:
        raise ValueError(
            f"Target component total {target_total} outside feasible range [{lower_total:.3f}, {upper_total:.3f}]"
        )

    # Bisection: move the lower end up while the total is still below target.
    for _ in range(80):
        midpoint = (lower + upper) / 2.0
        if evaluate(midpoint)[0] < target_total:
            lower = midpoint
        else:
            upper = midpoint

    # The rounded total is a step function of the offset, so compare both
    # bracket ends and their midpoint and keep the first closest candidate.
    chosen_probs: dict[str, list[float]] | None = None
    chosen_total = 0.0
    smallest_error = float("inf")
    for offset in (lower, upper, (lower + upper) / 2.0):
        total, probs = evaluate(offset)
        error = abs(total - target_total)
        if error < smallest_error:
            smallest_error, chosen_total, chosen_probs = error, total, probs

    if chosen_probs is None:
        raise RuntimeError("Component total calibration failed to produce probabilities")
    return chosen_probs, chosen_total


def expected_seats(samples: list[int]) -> float:
    """Return the arithmetic mean of the simulated seat counts.

    Args:
        samples: Simulated seat counts; must not be empty.

    Returns:
        The sample mean.
    """
    seat_sum = sum(samples)
    return seat_sum / len(samples)


def quantile(samples: list[int], q: float) -> int:
    """Return the sample at position ``int((n - 1) * q)`` of the sorted data.

    No interpolation is performed; the position is clamped to the valid
    index range.

    Args:
        samples: Simulated seat counts.
        q: Quantile level.

    Returns:
        The selected element of the sorted samples.
    """
    ordered = sorted(samples)
    last = len(ordered) - 1
    position = int(last * q)
    if position < 0:
        position = 0
    if position > last:
        position = last
    return ordered[position]


def new_canvas(w: int, h: int, bg: tuple[int, int, int] = (255, 255, 255)) -> list[list[tuple[int, int, int]]]:
    """Create an image as ``h`` independent rows of ``w`` pixels.

    Args:
        w: Width in pixels.
        h: Height in pixels.
        bg: Colour every pixel starts with (white by default).

    Returns:
        A list of row lists, indexed as ``img[y][x]``.
    """
    rows = []
    for _ in range(h):
        rows.append([bg] * w)
    return rows


def set_px(img, x, y, c):
    """Set pixel ``(x, y)`` of *img* to colour *c*, ignoring out-of-bounds coordinates.

    Args:
        img: Image as a list of rows, indexed ``img[y][x]``.
        x: Column index.
        y: Row index.
        c: Colour value to store.
    """
    if y < 0 or y >= len(img):
        return
    if x < 0 or x >= len(img[0]):
        return
    img[y][x] = c


def draw_line(img, x0, y0, x1, y1, c):
    """Draw a one-pixel-wide line from ``(x0, y0)`` to ``(x1, y1)`` inclusive.

    Uses integer error accumulation (Bresenham-style stepping); pixels that
    fall outside the image are skipped by ``set_px``.

    Args:
        img: Image as a list of rows.
        x0: Start column.
        y0: Start row.
        x1: End column.
        y1: End row.
        c: Colour value to store.
    """
    span_x = abs(x1 - x0)
    span_y = abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    error = span_x - span_y
    col, row = x0, y0
    set_px(img, col, row, c)
    while not (col == x1 and row == y1):
        doubled = 2 * error
        if doubled > -span_y:
            error -= span_y
            col += step_x
        if doubled < span_x:
            error += span_x
            row += step_y
        set_px(img, col, row, c)


def fill_rect(img, x0, y0, x1, y1, c):
    """Fill the axis-aligned rectangle spanned by two corners, clipped to the image.

    The corners may be given in either order and both are inclusive. The
    corner order is normalised *before* clipping, so a rectangle lying
    partly or entirely outside the image is clipped correctly instead of
    wrapping around through negative indices or raising ``IndexError``.

    Args:
        img: Image as a list of rows, indexed ``img[y][x]``.
        x0: Column of the first corner.
        y0: Row of the first corner.
        x1: Column of the opposite corner.
        y1: Row of the opposite corner.
        c: Colour value to store.
    """
    if not img or not img[0]:
        return
    left, right = sorted((x0, x1))
    top, bottom = sorted((y0, y1))
    left = max(0, left)
    right = min(len(img[0]) - 1, right)
    top = max(0, top)
    bottom = min(len(img) - 1, bottom)
    # When the rectangle misses the image the ranges below are empty.
    for y in range(top, bottom + 1):
        row = img[y]
        for x in range(left, right + 1):
            row[x] = c


def write_png(path: Path, img) -> None:
    """Encode *img* as an 8-bit RGB PNG and write it to *path*.

    Args:
        path: Destination file.
        img: Non-empty list of rows, each a list of ``(r, g, b)`` triples with
            channel values in 0..255.
    """
    height = len(img)
    width = len(img[0])

    # Every scanline is prefixed with filter type 0.
    scanlines = bytearray()
    for row in img:
        scanlines.append(0)
        for red, green, blue in row:
            scanlines.extend((red, green, blue))

    def chunk(tag: bytes, data: bytes) -> bytes:
        # Layout: length of data, tag, data, CRC-32 over tag + data.
        body = tag + data
        checksum = zlib.crc32(body) & 0xFFFFFFFF
        return struct.pack("!I", len(data)) + body + struct.pack("!I", checksum)

    # IHDR fields: width, height, bit depth 8, colour type 2, then three zero bytes.
    header = struct.pack("!IIBBBBB", width, height, 8, 2, 0, 0, 0)
    parts = [
        b"\x89PNG\r\n\x1a\n",
        chunk(b"IHDR", header),
        chunk(b"IDAT", zlib.compress(bytes(scanlines), 9)),
        chunk(b"IEND", b""),
    ]
    path.write_bytes(b"".join(parts))


def render_expected_seats_png(path: Path, expected_by_party: dict[str, float]) -> None:
    """Render a bar chart of expected seats, one bar per entry of PARTIES.

    Bars are scaled so that the largest value uses 95% of the plot height.
    Parties whose bar would be zero pixels tall (or whose value is not
    positive) get no bar, and an all-zero input no longer divides by zero.

    Args:
        path: Destination PNG file.
        expected_by_party: Expected seats keyed by party *label*; must contain
            every label in PARTIES.
    """
    w, h = 1100, 520
    ml, mr, mt, mb = 80, 30, 30, 80
    pw, ph = w - ml - mr, h - mt - mb
    img = new_canvas(w, h)  # white by default, so no background fill is needed
    axis = (40, 40, 40)
    baseline = mt + ph
    draw_line(img, ml, baseline, ml + pw, baseline, axis)
    draw_line(img, ml, mt, ml, baseline, axis)
    maxv = max(expected_by_party.values())
    slot = pw / len(PARTIES)
    bw = int(slot * 0.6)
    for i, p in enumerate(PARTIES):
        value = expected_by_party[p.label]
        bh = int((value / maxv) * (ph * 0.95)) if maxv > 0 else 0
        if bh <= 0:
            # fill_rect swaps inverted corners, so an empty bar would otherwise
            # be painted two pixels tall and overwrite the x-axis.
            continue
        xc = int(ml + (i + 0.5) * slot)
        fill_rect(img, xc - bw // 2, baseline - bh, xc + bw // 2, baseline - 1, hex_to_rgb(p.color))
    write_png(path, img)


def render_survival_png(path: Path, thresholds: list[int], probs_by_party: dict[str, list[float]]) -> None:
    """Render each party's survival curve as a polyline and write it as a PNG.

    The x-axis is scaled by the largest threshold (MAX_THRESHOLD_SEATS when
    ``thresholds`` is empty) and the y-axis maps probability 1.0 to the top
    of the plot area. Light grid lines mark 0.25, 0.5 and 0.75.

    Args:
        path: Destination PNG file.
        thresholds: Seat thresholds shared by all parties.
        probs_by_party: Probabilities per threshold, keyed by party *key*.
    """
    width, height = 1100, 560
    left, right, top, bottom = 80, 30, 30, 70
    plot_w = width - left - right
    plot_h = height - top - bottom
    canvas = new_canvas(width, height)
    axis_color = (40, 40, 40)
    grid_color = (230, 230, 230)
    baseline_y = top + plot_h
    draw_line(canvas, left, baseline_y, left + plot_w, baseline_y, axis_color)
    draw_line(canvas, left, top, left, baseline_y, axis_color)

    max_threshold = max(thresholds) if thresholds else MAX_THRESHOLD_SEATS

    def to_x(seats):
        return int(left + (seats / max_threshold) * plot_w)

    def to_y(probability):
        return int(top + (1.0 - probability) * plot_h)

    for level in (0.25, 0.5, 0.75):
        grid_y = to_y(level)
        draw_line(canvas, left, grid_y, left + plot_w, grid_y, grid_color)

    segment_count = len(thresholds) - 1
    for party in PARTIES:
        stroke = hex_to_rgb(party.color)
        curve = probs_by_party[party.key]
        for start in range(segment_count):
            end = start + 1
            draw_line(
                canvas,
                to_x(thresholds[start]),
                to_y(curve[start]),
                to_x(thresholds[end]),
                to_y(curve[end]),
                stroke,
            )
    write_png(path, canvas)


def write_assets(
    assets_dir: Path,
    thresholds: list[int],
    probs_by_party: dict[str, list[float]],
    expected_by_party_component: dict[str, float],
    expected_by_party_simulation: dict[str, float],
    interval90_by_party: dict[str, tuple[int, int]],
    component_expected_total_seats: float,
) -> None:
    """Write the two chart PNGs and ``model_summary.json`` into *assets_dir*.

    Args:
        assets_dir: Output directory; created if missing.
        thresholds: Seat thresholds used for the survival chart.
        probs_by_party: Calibrated probabilities per threshold, by party key.
        expected_by_party_component: Expected seats from the calibrated
            probabilities, by party label (also used for the bar chart).
        expected_by_party_simulation: Mean simulated seats, by party label.
        interval90_by_party: ``(low, high)`` seat bounds, by party label.
        component_expected_total_seats: Sum of the component expected seats.
    """
    assets_dir.mkdir(parents=True, exist_ok=True)

    bar_chart_path = assets_dir / "expected_seats.png"
    render_expected_seats_png(bar_chart_path, expected_by_party_component)
    curve_chart_path = assets_dir / "survival_curves.png"
    render_survival_png(curve_chart_path, thresholds, probs_by_party)

    total_delta = TARGET_TOTAL_SEATS - component_expected_total_seats
    intervals = {
        label: [bounds[0], bounds[1]]
        for label, bounds in interval90_by_party.items()
    }
    # "expected_seats" and "expected_seats_component" carry the same values.
    summary = {
        "model": "swing-first hierarchical simulation",
        "draws": SIM_DRAWS,
        "seed": RNG_SEED,
        "total_seats": TARGET_TOTAL_SEATS,
        "max_threshold": MAX_THRESHOLD_SEATS,
        "component_expected_total_seats": component_expected_total_seats,
        "component_expected_total_delta": total_delta,
        "expected_seats": expected_by_party_component,
        "expected_seats_component": expected_by_party_component,
        "expected_seats_simulation": expected_by_party_simulation,
        "interval90": intervals,
    }
    summary_text = json.dumps(summary, indent=2)
    summary_path = assets_dir / "model_summary.json"
    summary_path.write_text(summary_text, encoding="utf-8")


def write_csv(path: Path, assets_dir: Path) -> None:
    """Run the simulation and write the market CSV plus chart/summary assets.

    Steps: simulate seat draws, derive per-party survival probabilities,
    calibrate them so the component expected seats sum to
    TARGET_TOTAL_SEATS, validate the calibration, then write the CSV
    (``# market_*`` / ``# generated_*`` comment header followed by one row
    per party and threshold) and finally the assets.

    The calibration is validated *before* anything is written, so a failed
    run does not leave a CSV with out-of-tolerance probabilities on disk.

    Args:
        path: Destination CSV file; parent directories are created.
        assets_dir: Directory receiving the PNG charts and JSON summary.

    Raises:
        RuntimeError: If the calibrated component total differs from
            TARGET_TOTAL_SEATS by more than COMPONENT_TOTAL_TOLERANCE.
        ValueError: Propagated from the calibration when the target total
            is not reachable.
    """
    thresholds = seat_thresholds(MAX_THRESHOLD_SEATS, GLOBAL_THRESHOLD_STEP)
    widths = threshold_widths(thresholds)
    draws = simulate_seat_draws(SIM_DRAWS, RNG_SEED)
    party_rows: list[dict[str, str]] = []
    expected_by_party_simulation: dict[str, float] = {}
    interval90_by_party: dict[str, tuple[int, int]] = {}
    raw_probs_by_party: dict[str, list[float]] = {}

    for spec in PARTIES:
        samples = draws[spec.key]
        raw_probs_by_party[spec.key] = survival_from_samples(samples, thresholds)
        expected_by_party_simulation[spec.label] = expected_seats(samples)
        interval90_by_party[spec.label] = (quantile(samples, 0.05), quantile(samples, 0.95))

    probs_by_party, component_expected_total_seats = calibrate_probs_to_component_total(
        raw_probs_by_party,
        widths,
        TARGET_TOTAL_SEATS,
        decimals=PROBABILITY_DECIMALS,
    )

    # Fail fast: reject an out-of-tolerance calibration before touching the disk.
    component_total_delta = TARGET_TOTAL_SEATS - component_expected_total_seats
    if abs(component_total_delta) > COMPONENT_TOTAL_TOLERANCE:
        raise RuntimeError(
            f"Component expected total mismatch too large: delta={component_total_delta:.3f} seats"
        )

    expected_by_party_component: dict[str, float] = {}
    for spec in PARTIES:
        probs = probs_by_party[spec.key]
        expected_by_party_component[spec.label] = component_expected_seats_from_probs(probs, widths)
        for threshold, prob in zip(thresholds, probs):
            party_rows.append({
                "projection_group": spec.key,
                "threshold_decimal": str(threshold),
                "threshold_date": "",
                "initial_probability": f"{prob:.{PROBABILITY_DECIMALS}f}",
                "label": f"{spec.label} >= {threshold} seats",
                "end_date": END_DATE,
                "decay_rate": "0.010",
                "status": "open",
            })

    svelte_params = {
        "scaleType": "linear", "scaleBase": 1, "timeCadence": "yearly", "countUnitDisplay": "axis",
        "multicount": {"enabled": True, "totalSeats": TARGET_TOTAL_SEATS, "parties": [{"key": p.key, "label": p.label, "color": p.color} for p in PARTIES]},
    }

    metadata = [
        ("code", MARKET_CODE), ("title", "England local elections (May 2026): councillor seats by party"), ("type", "count"),
        ("visibility", "public"), ("status", "draft"), ("budget", "300"), ("decay_rate", "0.010"),
        ("resolution_criteria", "Resolves to final seat totals by party across the 134 English councils voting on 7 May 2026 (ordinary local elections only; excludes Surrey shadow authority elections). Source hierarchy: official local authority declarations then Open Council Data UK and BBC election result roundups."),
        ("x_unit", "X seats"), ("number_format", ",.0f"), ("end_date", END_DATE), ("resolution_date", RESOLUTION_DATE),
        ("cumulative", "false"), ("background_info_path", "background_info.html"),
        ("svelte_params", json.dumps(svelte_params, separators=(",", ":"))),
    ]

    simulation_total_expected = sum(expected_by_party_simulation.values())

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        # Comment header: market metadata first, then generation diagnostics.
        for key, value in metadata:
            handle.write(f"# market_{key}: {value}\n")
        handle.write(
            "# generated_note: "
            f"swing_mlm_draws={SIM_DRAWS}, seed={RNG_SEED}, threshold_step={GLOBAL_THRESHOLD_STEP}, "
            f"max_threshold={MAX_THRESHOLD_SEATS}, "
            f"simulation_expected_total={simulation_total_expected:.3f}, "
            f"component_expected_total={component_expected_total_seats:.3f}, "
            f"component_total_delta={component_total_delta:.3f}\n"
        )
        for label, value in expected_by_party_component.items():
            low, high = interval90_by_party[label]
            slug = label.lower().replace(" ", "_")
            handle.write(f"# generated_expected_{slug}: {value:.1f}\n")
            handle.write(f"# generated_expected_simulation_{slug}: {expected_by_party_simulation[label]:.1f}\n")
            handle.write(f"# generated_interval90_{slug}: [{low},{high}]\n")
        handle.write("\n")
        writer = csv.DictWriter(handle, fieldnames=["projection_group", "threshold_decimal", "threshold_date", "initial_probability", "label", "end_date", "decay_rate", "status"], lineterminator="\n")
        writer.writeheader()
        writer.writerows(party_rows)

    write_assets(
        assets_dir,
        thresholds,
        probs_by_party,
        expected_by_party_component,
        expected_by_party_simulation,
        interval90_by_party,
        component_expected_total_seats,
    )
    print(f"Wrote {path}")
    print(f"Wrote assets in {assets_dir}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the generator.

    Returns:
        Namespace with ``output`` (CSV path) and ``assets_dir`` (chart
        directory), both strings defaulting to the module-level paths.
    """
    cli = argparse.ArgumentParser(description="Generate multicount local-election market CSV")
    path_options = (
        ("--output", DEFAULT_OUTPUT_PATH, "Output CSV path"),
        ("--assets-dir", DEFAULT_ASSETS_DIR, "Directory for generated chart assets"),
    )
    for flag, default_path, help_text in path_options:
        cli.add_argument(flag, default=str(default_path), help=help_text)
    return cli.parse_args()


def main() -> int:
    """Generate the CSV and assets at the paths given on the command line.

    Returns:
        ``0`` on success, for use as the process exit status.
    """
    options = parse_args()
    csv_path = Path(options.output).resolve()
    assets_path = Path(options.assets_dir).resolve()
    write_csv(csv_path, assets_path)
    return 0


# Run the generator only when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
