#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import html
import json
import math
import os
import ssl
from dataclasses import dataclass
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
from statistics import NormalDist
from typing import Any
from urllib import parse, request

import matplotlib
import numpy as np
import pandas as pd

matplotlib.use("Agg")
import matplotlib.pyplot as plt


# Repository layout: this script is assumed to live two directories below the
# repository root (hence parents[2]); this market's data lives under
# data/hormuz-transit-calls/.
REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_DATA_DIR = REPO_ROOT / "data" / "hormuz-transit-calls"
SCRIPT_PATH = Path(__file__).resolve()
DEFAULT_CSV_PATH = DEFAULT_DATA_DIR / "market.csv"

# ArcGIS FeatureServer query endpoint behind IMF PortWatch's daily chokepoint
# data; fetch_features() pages through it.
PORTWATCH_QUERY_URL = (
    "https://services9.arcgis.com/weJ1QsnbMYJlCHdG/arcgis/rest/services/"
    "Daily_Chokepoints_Data/FeatureServer/0/query"
)
# Human-facing PortWatch page for the chokepoint dataset.
PORTWATCH_PAGE_URL = "https://portwatch.imf.org/pages/cb5856222a5b4105adc6ee7e880a1730"
# NOTE(review): the name suggests a Japan Coast Guard Tsugaru navigation guide,
# but the URL points at an NGA MSI PDF (Pub120bk.pdf) under an unusual
# "apology_objects" path - confirm the link resolves and is the intended source.
JCG_TSUGARU_NAV_GUIDE_URL = (
    "https://msi.nga.mil/apology_objects/Pub120bk.pdf"
)
# Cited news / report sources. AP_RED_SEA_REROUTE_URL and
# AP_SUEZ_EVERGIVEN_2023_URL are referenced by EXTERNAL_ANALOGUES below; the
# others are presumably cited elsewhere in this script (not visible here).
AP_JAPAN_SNOW_URL = "https://apnews.com/article/japan-heavy-snow-flight-train-cancellation-04010da7e4512af0b7f1a0e92840d2bd"
AP_JAPAN_HOKKAIDO_SNOW_2025_URL = "https://apnews.com/article/japan-snow-hokkaido-dc95bdfed1045b8970570bc0097e0b1f"
AP_RED_SEA_REROUTE_URL = "https://apnews.com/article/china-red-sea-tensions-global-trade-houthis-shipping-4f398a2e9d4143cd10ee729bda129be2"
AP_RED_SEA_UN_TRADE_URL = "https://apnews.com/article/un-global-trade-red-sea-ukraine-panama-440507f5e8b0b6961f611274b6b19269"
AP_KERCH_TRAFFIC_URL = "https://apnews.com/article/russia-ukraine-war-crimea-bridge-be0a8f3cb98c278aa4cbb14f8588cae4"
AP_SUEZ_EVERGIVEN_2023_URL = "https://apnews.com/article/copenhagen-suez-canal-egypt-business-e464615cbec0641e7ef4b8c9a721ac54"
AP_TAIWAN_DRILLS_2024_URL = "https://apnews.com/article/taiwan-china-military-drills-053fbadaf1125be4adfee7a7e6d1ce79"
# NOTE: despite the REUTERS_ prefix this is a Yahoo News URL (presumably a
# syndicated copy of the Reuters story).
REUTERS_BOSPHORUS_2022_URL = (
    "https://www.yahoo.com/news/tanker-jam-turkeys-bosphorus-persists-153946337.html"
)
IMF_COVID_SHIPPING_WP_URL = "https://www.imf.org/en/Publications/WP/Issues/2020/12/18/Supply-Spillovers-During-the-Pandemic-Evidence-from-High-Frequency-Shipping-Data-49966"
UNCTAD_COVID_MARITIME_2020_URL = "https://unctad.org/news/covid-19-cut-maritime-trade-42-first-half-2020"

# Market identifiers and display titles for the monthly, daily and weekly
# variants of the market.
DEFAULT_MARKET_CODE_MONTHLY = "hormuz-calls26"
DEFAULT_MARKET_CODE_DAILY = "hormuz-calls26-daily"
DEFAULT_MARKET_CODE_WEEKLY = "hormuz-weekly"
DEFAULT_MARKET_TITLE_MONTHLY = "Monthly average Strait of Hormuz daily transit calls"
DEFAULT_MARKET_TITLE_DAILY = "Strait of Hormuz daily transit calls"
DEFAULT_MARKET_TITLE_WEEKLY = "Strait of Hormuz weekly transit calls"
# Market-construction settings consumed outside this section (presumably a
# market budget, a decay rate and output rounding of probabilities - confirm).
MARKET_BUDGET = 500
MARKET_DECAY_RATE = 0.005
PROBABILITY_DECIMALS = 3

# PortWatch chokepoint identifier and name of the forecast target.
TARGET_PORT_ID = "chokepoint6"
TARGET_PORT_NAME = "Strait of Hormuz"
# Start of the modeled disruption (UTC) and the last date forecasts cover.
EVENT_START_DATE = datetime(2026, 3, 1, tzinfo=timezone.utc)
FORECAST_END_DATE = date(2026, 12, 31)

# Daily call-count thresholds for each target date's distribution. Spacing
# widens at higher counts, giving finer resolution near a shutdown.
COUNT_THRESHOLDS = [
    0,
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    8,
    9,
    10,
    12,
    15,
    18,
    22,
    27,
    33,
    40,
    50,
    60,
    75,
    90,
    110,
    130,
]
# Weekly-sum thresholds: every positive daily threshold scaled by 7 (the 0
# threshold is dropped).
WEEKLY_SUM_THRESHOLDS = [threshold * 7 for threshold in COUNT_THRESHOLDS if threshold > 0]

# Episode detection for reference classes (see detect_episodes):
# - baseline = trailing median of n_total over SHOCK_BASELINE_WINDOW_DAYS rows,
#   excluding the current row;
# - an episode starts when baseline >= SHOCK_MIN_BASELINE and
#   n_total / baseline < SHOCK_START_RATIO;
# - it ends once RECOVERY_CONSECUTIVE_DAYS consecutive rows have
#   ratio >= RECOVERY_RATIO.
SHOCK_BASELINE_WINDOW_DAYS = 28
SHOCK_MIN_BASELINE = 20.0
SHOCK_START_RATIO = 0.50
RECOVERY_RATIO = 0.70
RECOVERY_CONSECUTIVE_DAYS = 14
# Trough (Episode.min_ratio) cut-offs: <= SEVERE_GLOBAL_RATIO is "severe";
# (SEVERE_GLOBAL_RATIO, SEVERE_CONFLICT_RATIO] is "stress"; conflict-port
# episodes must reach <= SEVERE_CONFLICT_RATIO to join the conflict class.
SEVERE_GLOBAL_RATIO = 0.20
SEVERE_CONFLICT_RATIO = 0.35

# PortWatch chokepoints whose disruptions are treated as conflict-driven proxies.
CONFLICT_PROXY_PORTS = {
    "chokepoint1",   # Suez Canal
    "chokepoint3",   # Bosporus Strait
    "chokepoint4",   # Bab el-Mandeb Strait
    "chokepoint28",  # Kerch Strait
}
# Ports excluded from fitted disruption classes because ratio-based drop
# detection there can be dominated by diversion-route volatility rather than
# true shutdown dynamics.
EXCLUDED_FROM_FIT_PORTS = {
    "chokepoint7",  # Cape of Good Hope
}

# Reference-class mode for this market build.
# "conflict_only" limits recovery and attenuation fitting to conflict-affected
# episodes; any other value also fits the global severe/stress classes.
REFERENCE_CLASS_MODE = "conflict_only"

# Partial pooling parameters for the recovery-time model.
# POOLING_TAU: prior sd (log-days) when shrinking the conflict-class mean toward
#   the global severe mean (only used outside "conflict_only" mode).
# HORMUZ_SHIFT_MULTIPLIER: Hormuz median duration = conflict median x this factor.
# HORMUZ_SIGMA_FLOOR: minimum log-sd of the Hormuz component.
# EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT: weight of each additional episode from
#   an already-counted port in effective_sample_size().
POOLING_TAU = 0.55
HORMUZ_SHIFT_MULTIPLIER = 2.5
HORMUZ_SIGMA_FLOOR = 1.10
EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT = 0.35
# Mixture weights for the non-"conflict_only" mode. fit_recovery_model
# renormalizes them to sum to 1; with every class active they sum to 1.04 here.
WEIGHT_HORMUZ = 0.60
WEIGHT_CONFLICT = 0.23
WEIGHT_GLOBAL_SEVERE = 0.12
WEIGHT_GLOBAL_STRESS = 0.05
WEIGHT_EXTERNAL_ANALOGUES = 0.04
# Per-analogue weight applied to external analogues that overlap PortWatch data.
EXTERNAL_OVERLAP_DISCOUNT = 0.35
# Mixture weights used when REFERENCE_CLASS_MODE == "conflict_only".
WEIGHT_CONFLICT_ONLY_HORMUZ = 0.62
WEIGHT_CONFLICT_ONLY_CONFLICT = 0.33
WEIGHT_CONFLICT_ONLY_EXTERNAL = 0.05

# Post-recovery attenuation: recovered traffic can remain below full baseline.
# ATTENUATION_POOLING_TAU: prior sd for shrinking the conflict log-ratio toward
#   the reference-class log-ratio.
# ATTENUATION_RATIO_MIN / _MAX: clip for each episode's mean post-recovery ratio.
# ATTENUATION_RATIO_FLOOR / _CAP: clamp for fitted and applied multipliers.
# ATTENUATION_DECAY_FLOOR / _CAP: clamp for the per-day gap-closing rate.
# ATTENUATION_HORIZONS_DAYS: post-recovery horizons, in days.
#   NOTE(review): the visible fitting code uses literal 30/60/90 instead of
#   this tuple.
ATTENUATION_POOLING_TAU = 0.18
ATTENUATION_RATIO_MIN = 0.70
ATTENUATION_RATIO_MAX = 1.08
ATTENUATION_RATIO_FLOOR = 0.85
ATTENUATION_RATIO_CAP = 1.02
ATTENUATION_DECAY_FLOOR = 0.001
ATTENUATION_DECAY_CAP = 0.08
ATTENUATION_HORIZONS_DAYS = (30, 60, 90)

# Historical disruptions used as an extra log-duration prior in
# fit_recovery_model. Each entry's duration is (end_date - start_date).days + 1.
# Entries with overlap_with_portwatch=True are down-weighted by
# EXTERNAL_OVERLAP_DISCOUNT, since their episodes may already be in the fitted
# PortWatch data. In "conflict_only" mode only conflict_related entries are used,
# unless none qualify.
EXTERNAL_ANALOGUES = [
    {
        "key": "tanker_war_1984_1988",
        "name": "Tanker War (1984-1988)",
        "start_date": date(1984, 5, 13),
        "end_date": date(1988, 8, 20),
        "source_label": "Wikipedia: Tanker War",
        "source_url": "https://en.wikipedia.org/wiki/Tanker_War",
        "overlap_with_portwatch": False,
        "conflict_related": True,
    },
    {
        "key": "red_sea_bab_after_houthi",
        "name": "Red Sea / Bab el-Mandeb after Houthi attacks",
        "start_date": date(2023, 11, 19),
        "end_date": date(2025, 12, 31),
        "source_label": "AP: Red Sea attacks forced rerouting around Africa",
        "source_url": AP_RED_SEA_REROUTE_URL,
        "overlap_with_portwatch": True,
        "conflict_related": True,
    },
    {
        "key": "black_sea_after_2022_invasion",
        "name": "Black Sea after 2022 invasion",
        "start_date": date(2022, 2, 24),
        "end_date": date(2023, 8, 10),
        "source_label": "AP: temporary Black Sea corridor after grain-deal exit",
        "source_url": "https://apnews.com/article/855430e8a89444daa5ba1bafa557c280",
        "overlap_with_portwatch": True,
        "conflict_related": True,
    },
    {
        "key": "gulf_tanker_crisis_2019",
        "name": "2019 Gulf tanker crisis",
        "start_date": date(2019, 5, 12),
        "end_date": date(2020, 1, 3),
        "source_label": "AP: reference to 2019 limpet-mine tanker attacks",
        "source_url": "https://apnews.com/article/16f0c69a2314c17dcc08149b10861fa3",
        "overlap_with_portwatch": False,
        "conflict_related": True,
    },
    {
        "key": "ever_given_suez_blockage_2021",
        "name": "Ever Given / Suez blockage (2021)",
        "start_date": date(2021, 3, 23),
        "end_date": date(2021, 3, 29),
        "source_label": "AP: Ever Given blockage and legal aftermath",
        "source_url": AP_SUEZ_EVERGIVEN_2023_URL,
        "overlap_with_portwatch": True,
        "conflict_related": False,
    },
    {
        "key": "panama_drought_restrictions",
        "name": "Panama drought restrictions",
        "start_date": date(2023, 8, 1),
        "end_date": date(2024, 4, 30),
        "source_label": "AP: Panama Canal traffic cut by more than a third",
        "source_url": "https://apnews.com/article/bd76a77825a2e8e751a24346f8fd54a9",
        "overlap_with_portwatch": False,
        "conflict_related": False,
    },
]

# Regime windows for call-count distributions (UTC). fit_regime_params compares
# both ends inclusively (>= start, <= end). The bounds are midnight timestamps,
# so each end date is included only if PortWatch dates are stamped at 00:00 UTC
# (presumably the case - confirm).
# Baseline: normal traffic before the event, 2026-01-15 .. 2026-02-28.
# Collapse: first week after the event, 2026-03-02 .. 2026-03-08.
BASELINE_START = datetime(2026, 1, 15, tzinfo=timezone.utc)
BASELINE_END = datetime(2026, 2, 28, tzinfo=timezone.utc)
COLLAPSE_START = datetime(2026, 3, 2, tzinfo=timezone.utc)
COLLAPSE_END = datetime(2026, 3, 8, tzinfo=timezone.utc)


def build_ssl_context() -> ssl.SSLContext:
    """Create the TLS context used for HTTPS requests to PortWatch.

    Resolution order:

    1. ``SSL_CERT_FILE``, if it is set and names an existing regular file.
    2. The CA bundle shipped with ``certifi``, when that package is installed
       and its bundle loads.
    3. The interpreter's default trust store.

    Returns:
        A client-side ``ssl.SSLContext`` with certificate verification enabled.
    """
    cafile_env = os.getenv("SSL_CERT_FILE")
    # is_file() rather than exists(): a directory would make create_default_context
    # fail with an error instead of falling back.
    if cafile_env and Path(cafile_env).is_file():
        return ssl.create_default_context(cafile=cafile_env)

    try:
        import certifi  # type: ignore
    except ImportError:
        # certifi is optional; use the platform trust store.
        return ssl.create_default_context()

    try:
        return ssl.create_default_context(cafile=certifi.where())
    except OSError:
        # Missing/corrupt certifi bundle (ssl.SSLError subclasses OSError).
        return ssl.create_default_context()


SSL_CONTEXT = build_ssl_context()


@dataclass(frozen=True)
class Episode:
    """One detected traffic-disruption episode at a single PortWatch chokepoint."""

    portid: str
    portname: str
    start: datetime  # row where n_total/baseline first fell below SHOCK_START_RATIO
    end: datetime | None  # row that completed the recovery run; None if censored
    duration_days: int  # (end - start).days + 1, or up to the last observed row if censored
    min_ratio: float  # lowest n_total/baseline ratio within the episode
    censored: bool  # True when the series ended before recovery was observed


@dataclass(frozen=True)
class RecoveryModelParams:
    """Log-normal mixture parameters for the time (in days) until recovery.

    Each ``mu_*``/``sigma_*`` pair is the mean/sd of log(duration_days) for one
    reference class (global severe, global stress, external analogues,
    conflict-affected, Hormuz). The ``weight_*`` fields are the mixture weights,
    normalized to sum to 1 when built by fit_recovery_model.
    """

    mu_global_severe: float
    sigma_global_severe: float
    mu_global_stress: float
    sigma_global_stress: float
    mu_external: float
    sigma_external: float
    mu_conflict: float
    sigma_conflict: float
    mu_hormuz: float
    sigma_hormuz: float
    weight_hormuz: float
    weight_conflict: float
    weight_global_severe: float
    weight_global_stress: float
    weight_external: float


@dataclass(frozen=True)
class AttenuationModelParams:
    """Post-recovery traffic level relative to the pre-shock baseline.

    ``ratio{30,60,90}_*`` are clamped mean traffic/baseline ratios over the first
    30/60/90 days after recovery, for the reference ("global"), conflict-pooled
    and Hormuz classes. ``decay_per_day`` is the exponential rate at which the
    remaining gap to 1.0 closes after day 30, and
    ``half_life_days == ln(2) / decay_per_day``.
    """

    ratio30_global: float
    ratio30_conflict: float
    ratio30_hormuz: float
    ratio60_global: float
    ratio60_conflict: float
    ratio60_hormuz: float
    ratio90_global: float
    ratio90_conflict: float
    ratio90_hormuz: float
    decay_per_day: float
    half_life_days: float


@dataclass(frozen=True)
class RegimeParams:
    """Moment estimates of daily Hormuz call counts in two traffic regimes.

    ``*_mean``/``*_var`` are the sample mean and (floored) variance of the
    baseline and collapse windows. ``*_size`` is the negative-binomial size
    implied by them via method of moments: mean**2 / (var - mean), floored at 1.
    """

    baseline_mean: float
    baseline_var: float
    collapse_mean: float
    collapse_var: float
    baseline_size: float
    collapse_size: float


def fetch_features(
    where: str,
    out_fields: list[str],
    timeout: float | None = 60.0,
) -> pd.DataFrame:
    """Download every row matching ``where`` from the PortWatch ArcGIS endpoint.

    Pages through PORTWATCH_QUERY_URL with ``resultOffset``/``resultRecordCount``
    until the service returns an empty page. Results are ordered by
    ``portid,date`` so that offsets address a deterministic sequence across
    requests.

    Args:
        where: ArcGIS SQL ``where`` clause selecting the rows to fetch.
        out_fields: Attribute names to request.
        timeout: Per-request socket timeout in seconds (None = wait forever).

    Returns:
        One row per feature's ``attributes``. If a ``date`` column is present it
        is converted from epoch milliseconds to tz-aware UTC timestamps. When
        nothing matches, an empty frame with ``out_fields`` as its columns is
        returned.

    Raises:
        RuntimeError: if the service responds with an ArcGIS ``error`` object.
        urllib.error.URLError: on network failures or timeouts.
    """
    rows: list[dict[str, Any]] = []
    offset = 0

    while True:
        params = {
            "where": where,
            "outFields": ",".join(out_fields),
            "orderByFields": "portid,date",
            "resultOffset": str(offset),
            "resultRecordCount": "1000",
            "f": "json",
        }
        url = f"{PORTWATCH_QUERY_URL}?{parse.urlencode(params)}"
        with request.urlopen(url, context=SSL_CONTEXT, timeout=timeout) as resp:
            payload = json.loads(resp.read().decode("utf-8"))

        # ArcGIS reports query failures as HTTP 200 with an "error" object. Without
        # this check the loop would treat it as the final (empty) page and
        # silently return truncated data.
        if "error" in payload:
            err = payload.get("error") or {}
            raise RuntimeError(
                f"PortWatch query failed at offset {offset}: "
                f"{err.get('code')} {err.get('message')} {err.get('details')}"
            )

        features = payload.get("features") or []
        if not features:
            break

        rows.extend(feature["attributes"] for feature in features)
        offset += len(features)

    frame = pd.DataFrame(rows) if rows else pd.DataFrame(columns=list(out_fields))
    if "date" in frame.columns:
        frame["date"] = pd.to_datetime(frame["date"], unit="ms", utc=True)
    return frame


def is_suez_evergiven_non_conflict(ep: Episode) -> bool:
    """Tell whether ``ep`` is the 2021 Ever Given Suez blockage rather than conflict.

    True only for Suez Canal (``chokepoint1``) episodes whose start date falls
    between 2021-03-20 and 2021-04-20 inclusive. Such episodes are treated as
    an accidental, non-conflict disruption by the reference-class fit.
    """
    window_open, window_close = date(2021, 3, 20), date(2021, 4, 20)
    return ep.portid == "chokepoint1" and window_open <= ep.start.date() <= window_close


def build_ratio_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Return per-port daily traffic ratios against a trailing median baseline.

    For each ``portid`` (rows ordered by ``date``), the baseline is the median
    ``n_total`` of the previous SHOCK_BASELINE_WINDOW_DAYS rows (at least 14),
    excluding the current row. Rows without a baseline are dropped. The result
    has columns ``portid``, ``date``, ``baseline``, ``ratio``, where
    ratio = n_total / baseline. It is an empty frame with those columns if no
    port has enough history.
    """
    out_cols = ["portid", "date", "baseline", "ratio"]
    pieces: list[pd.DataFrame] = []

    for _portid, port_rows in df.sort_values(["portid", "date"]).groupby("portid"):
        frame = port_rows.sort_values("date").copy()
        trailing_median = (
            frame["n_total"].rolling(SHOCK_BASELINE_WINDOW_DAYS, min_periods=14).median()
        )
        # shift(1) keeps the current day out of its own baseline.
        frame["baseline"] = trailing_median.shift(1)
        frame["ratio"] = frame["n_total"] / frame["baseline"]
        frame = frame.dropna(subset=["baseline"])
        if len(frame) > 0:
            pieces.append(frame[out_cols])

    return pd.concat(pieces, ignore_index=True) if pieces else pd.DataFrame(columns=out_cols)


def detect_episodes(df: pd.DataFrame) -> list[Episode]:
    """Detect per-port traffic-collapse episodes from daily call counts.

    ``df`` needs ``portid``, ``portname``, ``date`` and ``n_total`` columns.
    For each port the baseline is the median ``n_total`` of the previous
    SHOCK_BASELINE_WINDOW_DAYS rows (at least 14; the current row is excluded via
    ``shift(1)``). Rows without a baseline are dropped.

    * An episode starts on a row with baseline >= SHOCK_MIN_BASELINE and
      n_total / baseline < SHOCK_START_RATIO.
    * It ends on the row that completes RECOVERY_CONSECUTIVE_DAYS consecutive
      rows (rows, not calendar days) with ratio >= RECOVERY_RATIO.
    * An episode still open at the last row is emitted with ``end=None`` and
      ``censored=True``; its duration runs to the last observed date.

    NOTE(review): the baseline keeps rolling during an episode. Once roughly half
    the window is depressed, the median itself falls to the collapsed level and
    the ratio returns to ~1 without any real recovery. A sustained collapse can
    therefore be recorded as only about a month long. Freezing the baseline at
    episode start would avoid this, but it would likely leave more conflict
    episodes censored. fit_recovery_model raises if fewer than 5 uncensored
    conflict episodes remain, so evaluate downstream effects before changing it.
    """
    episodes: list[Episode] = []

    for portid, group in df.sort_values(["portid", "date"]).groupby("portid"):
        g = group.sort_values("date").copy()
        g["baseline"] = (
            g["n_total"].rolling(SHOCK_BASELINE_WINDOW_DAYS, min_periods=14).median().shift(1)
        )
        g["ratio"] = g["n_total"] / g["baseline"]
        g = g.dropna(subset=["baseline"])
        if g.empty:
            continue

        portname = str(g["portname"].iloc[0])
        in_episode = False
        episode_start: datetime | None = None

        for _, row in g.iterrows():
            baseline = float(row["baseline"])
            ratio = float(row["ratio"])
            current_date = row["date"].to_pydatetime()

            # Shock onset: a meaningful baseline and a drop below SHOCK_START_RATIO.
            if not in_episode and baseline >= SHOCK_MIN_BASELINE and ratio < SHOCK_START_RATIO:
                in_episode = True
                episode_start = current_date
                # The onset row itself is never tested for recovery.
                continue

            if not in_episode or episode_start is None:
                continue

            # Last RECOVERY_CONSECUTIVE_DAYS rows up to and including today. This
            # re-filters the whole port frame for every in-episode row (O(n^2)).
            window = g[g["date"] <= row["date"]].tail(RECOVERY_CONSECUTIVE_DAYS)
            recovered = False
            if len(window) == RECOVERY_CONSECUTIVE_DAYS:
                # A zero baseline yields inf/NaN; NaN fails the >= check below.
                with np.errstate(divide="ignore", invalid="ignore"):
                    ratios = window["n_total"].to_numpy(dtype=float) / window["baseline"].to_numpy(dtype=float)
                recovered = bool(np.all(ratios >= RECOVERY_RATIO))

            if recovered:
                segment = g[(g["date"] >= episode_start) & (g["date"] <= row["date"])]
                end_date = current_date
                duration_days = (end_date - episode_start).days + 1
                min_ratio = float(segment["ratio"].min())
                episodes.append(
                    Episode(
                        portid=portid,
                        portname=portname,
                        start=episode_start,
                        end=end_date,
                        duration_days=duration_days,
                        min_ratio=min_ratio,
                        censored=False,
                    )
                )
                in_episode = False
                episode_start = None

        # Series ended mid-episode: emit a right-censored episode.
        if in_episode and episode_start is not None:
            end_date = g["date"].iloc[-1].to_pydatetime()
            segment = g[g["date"] >= episode_start]
            duration_days = (end_date - episode_start).days + 1
            min_ratio = float(segment["ratio"].min())
            episodes.append(
                Episode(
                    portid=portid,
                    portname=portname,
                    start=episode_start,
                    end=None,
                    duration_days=duration_days,
                    min_ratio=min_ratio,
                    censored=True,
                )
            )

    return episodes


def effective_sample_size(eps: list[Episode]) -> float:
    """Return a dependence-discounted count of episodes.

    The first episode from each distinct port counts as 1. Every additional
    episode from an already-represented port counts only
    EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT, because repeated episodes at one
    chokepoint are not independent. An empty list yields 0.0.
    """
    total = float(len(eps))
    if total <= 0:
        return 0.0
    distinct_ports = float(len(set(e.portid for e in eps)))
    repeats = total - distinct_ports
    return distinct_ports + EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT * repeats


def post_recovery_ratio_samples(
    eps: list[Episode],
    ratio_df: pd.DataFrame,
    horizon_days: int,
) -> list[tuple[Episode, float]]:
    """Collect mean traffic ratios over the ``horizon_days`` after each recovery.

    For every recovered (non-censored) episode, average ``ratio_df["ratio"]`` for
    the same port over the days end+1 .. end+horizon_days (inclusive). Episodes
    with fewer than max(12, ceil(0.8 * horizon_days)) rows in that window, or a
    non-finite mean, are skipped. Kept means are clipped to
    [ATTENUATION_RATIO_MIN, ATTENUATION_RATIO_MAX].

    Returns:
        ``(episode, clipped_mean_ratio)`` pairs in input order.
    """
    required = max(12, int(math.ceil(0.8 * float(horizon_days))))
    out: list[tuple[Episode, float]] = []

    for episode in eps:
        if episode.end is None:
            continue  # censored: there is no recovery date to measure from
        window_lo = episode.end + timedelta(days=1)
        window_hi = episode.end + timedelta(days=horizon_days)
        mask = (
            (ratio_df["portid"] == episode.portid)
            & (ratio_df["date"] >= window_lo)
            & (ratio_df["date"] <= window_hi)
        )
        ratios = ratio_df.loc[mask, "ratio"]
        if len(ratios) < required:
            continue
        mean_ratio = float(ratios.mean())
        if math.isfinite(mean_ratio):
            clipped = min(ATTENUATION_RATIO_MAX, max(ATTENUATION_RATIO_MIN, mean_ratio))
            out.append((episode, clipped))

    return out


def attenuation_multiplier(days_since_recovery: float, params: AttenuationModelParams) -> float:
    """Return the expected traffic/baseline ratio ``days_since_recovery`` days after recovery.

    Uses the Hormuz 30-day ratio as the level for the first 30 days. After day
    30, the remaining gap to 1.0 shrinks exponentially at
    ``params.decay_per_day``. A 30-day ratio of 0.999 or more means no
    attenuation (1.0 is returned). Other results are clamped to
    [ATTENUATION_RATIO_FLOOR, ATTENUATION_RATIO_CAP]. Negative inputs are
    treated as 0.
    """
    elapsed = max(0.0, float(days_since_recovery))
    initial = params.ratio30_hormuz
    if initial >= 0.999:
        return 1.0

    if elapsed > 30.0:
        shortfall = max(1e-6, 1.0 - initial)
        value = 1.0 - shortfall * math.exp(-params.decay_per_day * (elapsed - 30.0))
    else:
        value = initial
    return min(ATTENUATION_RATIO_CAP, max(ATTENUATION_RATIO_FLOOR, value))


def expected_recovered_multiplier(
    days_since_event: int,
    known_survival_days: int,
    recovery_params: RecoveryModelParams,
    attenuation_params: AttenuationModelParams,
) -> float:
    """Expected attenuation multiplier on ``days_since_event``, given recovery by then.

    Averages attenuation_multiplier(days since recovery) over the recovery-day
    distribution implied by conditioned_recovery_cdf (conditioned on
    ``known_survival_days``). Recovery days run from known_survival_days + 1
    through days_since_event. If the recovery probability by that day is at most
    1e-9, the day-0 multiplier is returned instead. The result is clamped to
    [ATTENUATION_RATIO_FLOOR, ATTENUATION_RATIO_CAP].
    """

    def cdf_at(day: float) -> float:
        return conditioned_recovery_cdf(
            day,
            known_survival_days=known_survival_days,
            params=recovery_params,
        )

    p_recovered = cdf_at(float(days_since_event))
    if p_recovered <= 1e-9:
        return attenuation_multiplier(0.0, attenuation_params)

    total = 0.0
    cdf_prev = cdf_at(float(known_survival_days))
    day = known_survival_days + 1
    while day <= days_since_event:
        cdf_cur = cdf_at(float(day))
        # Probability mass of recovering exactly on `day` (clipped at 0 for safety).
        day_mass = max(0.0, cdf_cur - cdf_prev)
        cdf_prev = cdf_cur
        total += day_mass * attenuation_multiplier(float(days_since_event - day), attenuation_params)
        day += 1

    mean_ratio = total / max(1e-9, p_recovered)
    return min(ATTENUATION_RATIO_CAP, max(ATTENUATION_RATIO_FLOOR, mean_ratio))


def fit_recovery_model(
    episodes: list[Episode],
    ratio_df: pd.DataFrame,
) -> tuple[RecoveryModelParams, AttenuationModelParams, dict[str, Any]]:
    """Fit the recovery-time mixture and the post-recovery attenuation model.

    Args:
        episodes: Output of detect_episodes(). Censored episodes are ignored for
            fitting, and EXCLUDED_FROM_FIT_PORTS are dropped from the classes.
        ratio_df: Output of build_ratio_frame(). Used to measure traffic ratios
            in the 30/60/90 days after each recovery.

    Returns:
        ``(recovery_params, attenuation_params, diagnostics)``. ``diagnostics``
        summarizes class sizes, effective sample sizes, duration quantiles,
        example episodes and external-analogue coverage.

    Raises:
        RuntimeError: if the severe reference class or the conflict-affected
            class has fewer than 5 episodes.
    """
    uncensored = [ep for ep in episodes if not ep.censored]
    fit_eligible = [ep for ep in uncensored if ep.portid not in EXCLUDED_FROM_FIT_PORTS]
    excluded_fit_eps = [ep for ep in uncensored if ep.portid in EXCLUDED_FROM_FIT_PORTS]

    # Suez episodes starting in the 2021 Ever Given window are treated as non-conflict.
    suez_non_conflict_eps = [ep for ep in fit_eligible if is_suez_evergiven_non_conflict(ep)]

    # Conflict class: proxy-port episodes (minus Ever Given) whose trough ratio
    # reached SEVERE_CONFLICT_RATIO or lower.
    conflict_affected_eps = [
        ep
        for ep in fit_eligible
        if ep.portid in CONFLICT_PROXY_PORTS
        and not is_suez_evergiven_non_conflict(ep)
        and ep.min_ratio <= SEVERE_CONFLICT_RATIO
    ]
    conflict_only_mode = REFERENCE_CLASS_MODE == "conflict_only"
    severe_non_conflict_eps: list[Episode] = []
    stress_non_conflict_eps: list[Episode] = []
    severe_reference_eps: list[Episode] = []

    # In conflict_only mode the conflict class doubles as the severe reference class.
    if conflict_only_mode:
        severe_reference_eps = list(conflict_affected_eps)
    else:
        severe_non_conflict_eps = [
            ep
            for ep in fit_eligible
            if ep.portid not in CONFLICT_PROXY_PORTS and ep.min_ratio <= SEVERE_GLOBAL_RATIO
        ]
        stress_non_conflict_eps = [
            ep
            for ep in fit_eligible
            if ep.portid not in CONFLICT_PROXY_PORTS and SEVERE_GLOBAL_RATIO < ep.min_ratio <= SEVERE_CONFLICT_RATIO
        ]
        stress_non_conflict_eps.extend(suez_non_conflict_eps)

        # Fallback only if non-conflict severe class is too sparse.
        severe_all_eps = [ep for ep in fit_eligible if ep.min_ratio <= SEVERE_GLOBAL_RATIO]
        severe_reference_eps = (
            severe_non_conflict_eps if len(severe_non_conflict_eps) >= 5 else severe_all_eps
        )

    if len(severe_reference_eps) < 5:
        raise RuntimeError("Insufficient reference-class episodes for model fitting")
    if len(conflict_affected_eps) < 5:
        raise RuntimeError("Insufficient conflict-affected episodes for model fitting")

    # Mean/sd of log durations. The sd is inflated by sqrt(n / n_eff) to account
    # for within-port dependence, then floored at sigma_floor.
    def fit_class_stats(eps: list[Episode], sigma_floor: float) -> tuple[float, float, float]:
        vals = np.log(np.array([float(ep.duration_days) for ep in eps], dtype=float))
        mu = float(np.mean(vals))
        sigma_raw = float(np.std(vals, ddof=1)) if len(vals) > 1 else sigma_floor
        n_eff = max(1.0, effective_sample_size(eps))
        inflation = math.sqrt(max(1.0, float(len(vals)) / n_eff))
        sigma = max(sigma_floor, sigma_raw * inflation)
        return mu, sigma, n_eff

    if conflict_only_mode:
        mu_conflict, sigma_conflict, n_eff_conflict = fit_class_stats(
            conflict_affected_eps, sigma_floor=0.30
        )
        mu_global_severe = mu_conflict
        sigma_global_severe = sigma_conflict
        n_eff_global_severe = n_eff_conflict
        mu_global_stress = mu_conflict
        sigma_global_stress = sigma_conflict
        n_eff_global_stress = 0.0
        has_stress_class = False
    else:
        mu_global_severe, sigma_global_severe, n_eff_global_severe = fit_class_stats(
            severe_reference_eps, sigma_floor=0.30
        )

        # Partial pooling: normal-normal shrinkage of the conflict log-mean toward
        # the global severe mean (prior sd POOLING_TAU). The predictive sd adds
        # the posterior variance to the global severe variance.
        conflict_logs = np.log(
            np.array([float(ep.duration_days) for ep in conflict_affected_eps], dtype=float)
        )
        mean_conflict = float(np.mean(conflict_logs))
        n_eff_conflict = max(1.0, effective_sample_size(conflict_affected_eps))

        sigma2 = sigma_global_severe * sigma_global_severe
        tau2 = POOLING_TAU * POOLING_TAU
        post_var = 1.0 / ((n_eff_conflict / sigma2) + (1.0 / tau2))
        mu_conflict = post_var * (
            (n_eff_conflict * mean_conflict / sigma2) + (mu_global_severe / tau2)
        )
        sigma_conflict_base = math.sqrt(sigma2 + post_var)
        sigma_conflict = max(
            0.30,
            sigma_conflict_base
            * math.sqrt(max(1.0, float(len(conflict_logs)) / n_eff_conflict)),
        )

        has_stress_class = len(stress_non_conflict_eps) >= 5
        if has_stress_class:
            mu_global_stress, sigma_global_stress, n_eff_global_stress = fit_class_stats(
                stress_non_conflict_eps, sigma_floor=0.35
            )
        else:
            mu_global_stress = mu_global_severe
            sigma_global_stress = max(0.35, sigma_global_severe)
            n_eff_global_stress = 0.0

    # External analogue prior. In conflict_only mode, use only conflict-related
    # analogues (all of them if none qualify).
    selected_external_analogues = (
        [item for item in EXTERNAL_ANALOGUES if bool(item.get("conflict_related"))]
        if conflict_only_mode
        else list(EXTERNAL_ANALOGUES)
    )
    if not selected_external_analogues:
        selected_external_analogues = list(EXTERNAL_ANALOGUES)

    ext_durations = np.array(
        [
            float((item["end_date"] - item["start_date"]).days + 1)
            for item in selected_external_analogues
        ],
        dtype=float,
    )
    # Analogues overlapping PortWatch data are down-weighted to avoid double counting.
    ext_weights = np.array(
        [
            EXTERNAL_OVERLAP_DISCOUNT if bool(item.get("overlap_with_portwatch")) else 1.0
            for item in selected_external_analogues
        ],
        dtype=float,
    )
    ext_logs = np.log(ext_durations)
    mu_external = float(np.average(ext_logs, weights=ext_weights))
    var_external = float(np.average((ext_logs - mu_external) ** 2, weights=ext_weights))
    sigma_external_raw = math.sqrt(max(1e-9, var_external))
    # Kish effective sample size of the weighted analogues: (sum w)^2 / sum(w^2).
    n_eff_external = max(1.0, float((ext_weights.sum() ** 2) / np.square(ext_weights).sum()))
    sigma_external = max(
        0.70,
        sigma_external_raw * math.sqrt(max(1.0, float(len(ext_logs)) / n_eff_external)),
    )

    # Hormuz component: conflict class with its median scaled by
    # HORMUZ_SHIFT_MULTIPLIER and a wider sd floor.
    mu_hormuz = mu_conflict + math.log(HORMUZ_SHIFT_MULTIPLIER)
    sigma_hormuz = max(HORMUZ_SIGMA_FLOOR, sigma_conflict)

    if conflict_only_mode:
        w_h = WEIGHT_CONFLICT_ONLY_HORMUZ
        w_c = WEIGHT_CONFLICT_ONLY_CONFLICT
        w_gs = 0.0
        w_gt = 0.0
        w_ext = WEIGHT_CONFLICT_ONLY_EXTERNAL if len(selected_external_analogues) > 0 else 0.0
    else:
        w_h = WEIGHT_HORMUZ
        w_c = WEIGHT_CONFLICT
        w_gs = WEIGHT_GLOBAL_SEVERE
        w_gt = WEIGHT_GLOBAL_STRESS if has_stress_class else 0.0
        w_ext = WEIGHT_EXTERNAL_ANALOGUES
    # Normalizer so the stored mixture weights sum to 1.
    w_sum = max(1e-9, w_h + w_c + w_gs + w_gt + w_ext)

    params = RecoveryModelParams(
        mu_global_severe=mu_global_severe,
        sigma_global_severe=sigma_global_severe,
        mu_global_stress=mu_global_stress,
        sigma_global_stress=sigma_global_stress,
        mu_external=mu_external,
        sigma_external=sigma_external,
        mu_conflict=mu_conflict,
        sigma_conflict=sigma_conflict,
        mu_hormuz=mu_hormuz,
        sigma_hormuz=sigma_hormuz,
        weight_hormuz=w_h / w_sum,
        weight_conflict=w_c / w_sum,
        weight_global_severe=w_gs / w_sum,
        weight_global_stress=w_gt / w_sum,
        weight_external=w_ext / w_sum,
    )

    # Placeholder; reassigned below once the attenuation fit is complete.
    attenuation_stats: dict[str, Any] = {}

    # Returns clamped (reference, conflict-pooled, hormuz) post-recovery ratios for
    # one horizon, plus diagnostics. Empty classes fall back to a ratio of 1.0
    # (reference) or to the reference samples (conflict).
    def fit_att_ratio(
        global_eps: list[Episode],
        conflict_eps: list[Episode],
        horizon_days: int,
    ) -> tuple[float, float, float, dict[str, Any]]:
        global_samples = post_recovery_ratio_samples(global_eps, ratio_df, horizon_days)
        conflict_samples = post_recovery_ratio_samples(conflict_eps, ratio_df, horizon_days)

        global_vals = np.array([s[1] for s in global_samples], dtype=float)
        conflict_vals = np.array([s[1] for s in conflict_samples], dtype=float)

        if len(global_vals) == 0:
            global_vals = np.array([1.0], dtype=float)
        if len(conflict_vals) == 0:
            conflict_vals = global_vals.copy()
            conflict_samples = global_samples

        global_logs = np.log(global_vals)
        conflict_logs_att = np.log(conflict_vals)
        mu_global = float(np.mean(global_logs))
        sigma_global = max(0.06, float(np.std(global_logs, ddof=1)) if len(global_logs) > 1 else 0.06)
        n_eff_conflict_att = max(1.0, effective_sample_size([s[0] for s in conflict_samples]))
        mean_conflict_att = float(np.mean(conflict_logs_att))

        # Same normal-normal shrinkage as the duration model, on log ratios.
        sigma2_att = sigma_global * sigma_global
        tau2_att = ATTENUATION_POOLING_TAU * ATTENUATION_POOLING_TAU
        post_var_att = 1.0 / ((n_eff_conflict_att / sigma2_att) + (1.0 / tau2_att))
        mu_conflict_att = post_var_att * (
            (n_eff_conflict_att * mean_conflict_att / sigma2_att) + (mu_global / tau2_att)
        )

        ratio_global = min(ATTENUATION_RATIO_CAP, max(ATTENUATION_RATIO_FLOOR, math.exp(mu_global)))
        ratio_conflict = min(ATTENUATION_RATIO_CAP, max(ATTENUATION_RATIO_FLOOR, math.exp(mu_conflict_att)))
        ratio_hormuz = ratio_conflict

        class_diag = {
            "horizon_days": horizon_days,
            "global_n": int(len(global_vals)),
            "conflict_n": int(len(conflict_vals)),
            "global_median": float(np.median(global_vals)),
            "conflict_median": float(np.median(conflict_vals)),
            "global_share_below_95pct": float(np.mean(global_vals < 0.95)),
            "conflict_share_below_95pct": float(np.mean(conflict_vals < 0.95)),
        }
        return ratio_global, ratio_conflict, ratio_hormuz, class_diag

    att_global_eps = conflict_affected_eps if conflict_only_mode else fit_eligible
    r30_g, r30_c, r30_h, att30 = fit_att_ratio(att_global_eps, conflict_affected_eps, 30)
    r60_g, r60_c, r60_h, att60 = fit_att_ratio(att_global_eps, conflict_affected_eps, 60)
    r90_g, r90_c, r90_h, att90 = fit_att_ratio(att_global_eps, conflict_affected_eps, 90)

    # Force the Hormuz ratios to be non-decreasing across horizons. Then derive the
    # exponential rate at which the gap to 1.0 closes between day 30 and day 90
    # (60 days apart), clamped to [ATTENUATION_DECAY_FLOOR, ATTENUATION_DECAY_CAP].
    r60_h = max(r30_h, r60_h)
    r90_h = max(r60_h, r90_h)
    gap30 = max(1e-4, 1.0 - r30_h)
    gap90 = max(1e-4, 1.0 - r90_h)
    if gap90 >= gap30:
        decay_per_day = ATTENUATION_DECAY_FLOOR
    else:
        decay_per_day = -math.log(gap90 / gap30) / 60.0
    decay_per_day = min(ATTENUATION_DECAY_CAP, max(ATTENUATION_DECAY_FLOOR, decay_per_day))
    half_life_days = math.log(2.0) / decay_per_day

    attenuation_params = AttenuationModelParams(
        ratio30_global=r30_g,
        ratio30_conflict=r30_c,
        ratio30_hormuz=r30_h,
        ratio60_global=r60_g,
        ratio60_conflict=r60_c,
        ratio60_hormuz=r60_h,
        ratio90_global=r90_g,
        ratio90_conflict=r90_c,
        ratio90_hormuz=r90_h,
        decay_per_day=decay_per_day,
        half_life_days=half_life_days,
    )
    attenuation_stats = {
        "ratio_horizon_30d": att30,
        "ratio_horizon_60d": att60,
        "ratio_horizon_90d": att90,
        "ratio30_global": r30_g,
        "ratio30_conflict_pooled": r30_c,
        "ratio30_hormuz_used": r30_h,
        "ratio60_hormuz_used": r60_h,
        "ratio90_hormuz_used": r90_h,
        "decay_per_day": decay_per_day,
        "half_life_days": half_life_days,
    }

    def ep_to_row(ep: Episode) -> dict[str, Any]:
        return {
            "portid": ep.portid,
            "portname": ep.portname,
            "start_date_utc": ep.start.date().isoformat(),
            "end_date_utc": ep.end.date().isoformat() if ep.end else None,
            "duration_days": ep.duration_days,
            "min_ratio": ep.min_ratio,
        }

    # Examples emphasized in background info (longer disruptions are most informative for tails).
    conflict_examples = sorted(
        conflict_affected_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]
    global_severe_examples = sorted(
        severe_reference_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]
    global_stress_examples = sorted(
        stress_non_conflict_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]

    # Ensure an Ever Given example appears among the stress examples when one exists.
    if suez_non_conflict_eps and not any(ep.portid == "chokepoint1" for ep in global_stress_examples):
        suez_top = sorted(
            suez_non_conflict_eps,
            key=lambda ep: (ep.duration_days, -ep.min_ratio),
            reverse=True,
        )[0]
        global_stress_examples = ([suez_top] + global_stress_examples)[:8]

    # Diagnostic-only example sets (drawn from all uncensored episodes; not used in the fit).
    red_sea_route_eps = [
        ep
        for ep in uncensored
        if ep.portid in {"chokepoint1", "chokepoint4"}
        and ep.start.date() >= date(2023, 10, 1)
        and ep.min_ratio <= 0.60
    ]
    red_sea_route_examples = sorted(
        red_sea_route_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]
    black_sea_war_eps = [
        ep
        for ep in uncensored
        if ep.portid == "chokepoint28" and ep.start.date() >= date(2022, 2, 24)
    ]
    black_sea_war_examples = sorted(
        black_sea_war_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]
    gulf_2019_eps = [
        ep
        for ep in uncensored
        if ep.portid == "chokepoint6" and date(2019, 5, 1) <= ep.start.date() <= date(2020, 1, 31)
    ]
    gulf_2019_examples = sorted(
        gulf_2019_eps,
        key=lambda ep: (ep.duration_days, -ep.min_ratio),
        reverse=True,
    )[:8]

    severe_reference_non_conflict_only = (
        False if conflict_only_mode else len(severe_reference_eps) == len(severe_non_conflict_eps)
    )
    overlap_removed_count = (
        0
        if conflict_only_mode
        else len(
            [
                ep
                for ep in fit_eligible
                if ep.portid in CONFLICT_PROXY_PORTS
                and not is_suez_evergiven_non_conflict(ep)
                and ep.min_ratio <= SEVERE_GLOBAL_RATIO
            ]
        )
    )

    severe_reference_days = [float(ep.duration_days) for ep in severe_reference_eps]
    conflict_days = [float(ep.duration_days) for ep in conflict_affected_eps]
    stress_days = [float(ep.duration_days) for ep in stress_non_conflict_eps]

    diagnostics = {
        "fit_eligible_count": len(fit_eligible),
        "reference_class_mode": REFERENCE_CLASS_MODE,
        "excluded_from_fit_count": len(excluded_fit_eps),
        "excluded_from_fit_by_port": {
            str(pid): len([ep for ep in excluded_fit_eps if ep.portid == pid])
            for pid in sorted({ep.portid for ep in excluded_fit_eps})
        },
        "conflict_affected_count": len(conflict_affected_eps),
        "severe_global_non_conflict_count": len(severe_non_conflict_eps),
        "stress_global_non_conflict_count": len(stress_non_conflict_eps),
        "suez_non_conflict_override_count": len(suez_non_conflict_eps),
        "severe_reference_non_conflict_only": severe_reference_non_conflict_only,
        "severe_reference_count": len(severe_reference_eps),
        "overlap_removed_count": overlap_removed_count,
        "effective_sample_fraction_within_port": EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT,
        "effective_sample_sizes": {
            "conflict_affected": n_eff_conflict,
            "severe_reference": n_eff_global_severe,
            "stress_non_conflict": n_eff_global_stress,
            "external_analogues": n_eff_external,
        },
        "external_overlap_discount": EXTERNAL_OVERLAP_DISCOUNT,
        "post_recovery_attenuation": attenuation_stats,
        "external_analogues": [
            {
                "key": str(item["key"]),
                "name": str(item["name"]),
                "duration_days": int((item["end_date"] - item["start_date"]).days + 1),
                "overlap_with_portwatch": bool(item.get("overlap_with_portwatch")),
                "conflict_related": bool(item.get("conflict_related")),
                "source_label": str(item["source_label"]),
                "source_url": str(item["source_url"]),
            }
            for item in selected_external_analogues
        ],
        "severe_reference_duration_days": {
            "median": float(np.median(severe_reference_days)),
            "p90": float(np.quantile(severe_reference_days, 0.90)),
        },
        "conflict_affected_duration_days": {
            "median": float(np.median(conflict_days)),
            "p90": float(np.quantile(conflict_days, 0.90)),
        },
        "stress_non_conflict_duration_days": (
            {
                "median": float(np.median(stress_days)),
                "p90": float(np.quantile(stress_days, 0.90)),
            }
            if stress_days
            else None
        ),
        "conflict_proxy_examples_top_duration": [ep_to_row(ep) for ep in conflict_examples],
        "global_severe_examples_top_duration": [ep_to_row(ep) for ep in global_severe_examples],
        "global_stress_examples_top_duration": [ep_to_row(ep) for ep in global_stress_examples],
        "red_sea_route_examples_top_duration": [ep_to_row(ep) for ep in red_sea_route_examples],
        "black_sea_war_examples_top_duration": [ep_to_row(ep) for ep in black_sea_war_examples],
        "gulf_2019_examples_top_duration": [ep_to_row(ep) for ep in gulf_2019_examples],
        "analogue_coverage": {
            "tanker_war_1984_1988": {
                "included_in_model": any(item["key"] == "tanker_war_1984_1988" for item in selected_external_analogues),
                "mode": "external_duration_prior",
            },
            "red_sea_bab_after_houthi": {
                "included_in_model": any(item["key"] == "red_sea_bab_after_houthi" for item in selected_external_analogues),
                "mode": "portwatch_plus_external_prior",
                "episodes_detected": len(red_sea_route_eps),
            },
            "black_sea_after_2022_invasion": {
                "included_in_model": any(item["key"] == "black_sea_after_2022_invasion" for item in selected_external_analogues),
                "mode": "portwatch_plus_external_prior",
                "episodes_detected": len(black_sea_war_eps),
            },
            "gulf_tanker_crisis_2019": {
                "included_in_model": any(item["key"] == "gulf_tanker_crisis_2019" for item in selected_external_analogues),
                "mode": "external_duration_prior",
                "episodes_detected": len(gulf_2019_eps),
            },
            "ever_given_suez_blockage_2021": {
                "included_in_model": any(item["key"] == "ever_given_suez_blockage_2021" for item in selected_external_analogues),
                "mode": "portwatch_plus_external_prior",
                "episodes_detected": len(suez_non_conflict_eps),
            },
            "panama_drought_restrictions": {
                "included_in_model": any(item["key"] == "panama_drought_restrictions" for item in selected_external_analogues),
                "mode": "external_duration_prior",
            },
        },
    }

    return params, attenuation_params, diagnostics


def fit_regime_params(hormuz_df: pd.DataFrame) -> RegimeParams:
    """Fit negative-binomial parameters for the baseline and collapse regimes.

    A method-of-moments fit on ``n_total`` inside the inclusive date windows
    ``[BASELINE_START, BASELINE_END]`` and ``[COLLAPSE_START, COLLAPSE_END]``.
    Sample variances are floored at 1.05x (baseline) and 1.20x (collapse) the
    mean so both regimes are over-dispersed, and the NB size
    ``mean**2 / (var - mean)`` is floored at 1.0.

    If the collapse window holds fewer than 3 rows, the last 7 rows by date
    are used for the collapse regime instead.

    Raises:
        RuntimeError: if fewer than 10 baseline observations are available.
    """
    dates = hormuz_df["date"]
    counts = hormuz_df["n_total"]

    def window(lo: Any, hi: Any) -> np.ndarray:
        # Inclusive on both ends.
        in_range = (dates >= lo) & (dates <= hi)
        return counts[in_range].to_numpy(dtype=float)

    baseline = window(BASELINE_START, BASELINE_END)
    collapse = window(COLLAPSE_START, COLLAPSE_END)

    if len(baseline) < 10:
        raise RuntimeError("Insufficient baseline observations for regime fitting")
    if len(collapse) < 3:
        # Too few points in the collapse window: fall back to the most recent rows.
        collapse = hormuz_df.sort_values("date").tail(7)["n_total"].to_numpy(dtype=float)

    def moment_fit(sample: np.ndarray, dispersion_floor: float) -> tuple[float, float, float]:
        # Returns (mean, floored variance, NB size parameter).
        mean = float(np.mean(sample))
        var = max(float(np.var(sample, ddof=1)), mean * dispersion_floor)
        size = max(1.0, (mean * mean) / max(1e-9, var - mean))
        return mean, var, size

    base_mean, base_var, base_size = moment_fit(baseline, 1.05)
    crash_mean, crash_var, crash_size = moment_fit(collapse, 1.20)

    return RegimeParams(
        baseline_mean=base_mean,
        baseline_var=base_var,
        collapse_mean=crash_mean,
        collapse_var=crash_var,
        baseline_size=base_size,
        collapse_size=crash_size,
    )


def lognormal_cdf(x: float, mu: float, sigma: float) -> float:
    """Return P(X <= x) for X ~ LogNormal(mu, sigma).

    ``mu`` and ``sigma`` are the mean and standard deviation of ``log X``.
    Non-positive ``x`` has probability zero. A non-positive ``sigma`` is
    treated as the degenerate limit, a point mass at ``exp(mu)``, instead of
    dividing by zero (``sigma == 0``) or returning an inverted CDF
    (``sigma < 0``).
    """
    if x <= 0:
        return 0.0
    log_x = math.log(x)
    if sigma <= 0:
        # Right-continuous step CDF of the point mass at exp(mu).
        return 1.0 if log_x >= mu else 0.0
    return NormalDist().cdf((log_x - mu) / sigma)


def recovery_cdf(days: float, params: RecoveryModelParams) -> float:
    """Return the unconditional recovery-duration CDF at ``days``.

    The CDF is a weighted sum of five log-normal components (Hormuz,
    conflict, global severe, global stress, external analogues), each
    evaluated with ``lognormal_cdf``. The weights are taken as-is from
    ``params``; no normalization happens here.
    """
    components = (
        (params.weight_hormuz, params.mu_hormuz, params.sigma_hormuz),
        (params.weight_conflict, params.mu_conflict, params.sigma_conflict),
        (params.weight_global_severe, params.mu_global_severe, params.sigma_global_severe),
        (params.weight_global_stress, params.mu_global_stress, params.sigma_global_stress),
        (params.weight_external, params.mu_external, params.sigma_external),
    )
    total = 0.0
    for weight, mu, sigma in components:
        total += weight * lognormal_cdf(days, mu, sigma)
    return total


def conditioned_recovery_cdf(days: float, known_survival_days: int, params: RecoveryModelParams) -> float:
    """Return P(recovered by ``days`` | still disrupted at ``known_survival_days``).

    Computed as ``(F(days) - F(k)) / (1 - F(k))`` with ``F = recovery_cdf`` and
    ``k = known_survival_days``. The denominator is floored at 1e-9 and the
    result is clamped to ``[0, 1]``. Days at or before ``k`` return 0.0.
    """
    if days <= known_survival_days:
        # Recovery cannot have happened before a day already observed as disrupted.
        return 0.0
    already = recovery_cdf(float(known_survival_days), params)
    by_day = recovery_cdf(days, params)
    remaining_mass = max(1e-9, 1.0 - already)
    return max(0.0, min(1.0, (by_day - already) / remaining_mass))


def render_charts(
    hormuz_df: pd.DataFrame,
    episodes: list[Episode],
    known_survival_days: int,
    recovery_params: RecoveryModelParams,
    out_dir: Path,
) -> None:
    """Write the two diagnostic charts as PNG files into ``out_dir``.

    * ``hormuz_calls_recent.png`` shows daily ``n_total`` since 2026-01-01 with
      a 7-row rolling mean (a 7-day mean when no days are missing) and a
      vertical marker at ``EVENT_START_DATE``.
    * ``recovery_duration_fit.png`` shows two panels. The left panel has
      histograms of uncensored reference-episode durations. The right panel
      shows the recovery CDF conditioned on ``known_survival_days`` up to day
      365 after the event.

    Args:
        hormuz_df: Daily Hormuz series with a tz-aware ``date`` column and ``n_total``.
        episodes: Detected disruption episodes used for the duration histograms.
        known_survival_days: Days the current disruption is already known to have lasted.
        recovery_params: Fitted recovery-duration mixture.
        out_dir: Output directory; created if missing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    hormuz_chart_path = out_dir / "hormuz_calls_recent.png"
    recovery_fit_path = out_dir / "recovery_duration_fit.png"
    # Derive the displayed event date from the same constant that positions the
    # marker, so the labels cannot drift from the plotted line.
    event_label = f"{EVENT_START_DATE:%Y-%m-%d}"

    # Chart 1: recent Hormuz daily calls with collapse marker.
    recent = hormuz_df[hormuz_df["date"] >= datetime(2026, 1, 1, tzinfo=timezone.utc)].copy()
    recent["roll7"] = recent["n_total"].rolling(7, min_periods=1).mean()
    fig, ax = plt.subplots(figsize=(10, 4.8))
    ax.plot(recent["date"], recent["n_total"], color="#0b84a5", linewidth=1.8, label="Daily n_total")
    ax.plot(recent["date"], recent["roll7"], color="#f6c85f", linewidth=2.2, label="7-day mean")
    ax.axvline(
        EVENT_START_DATE,
        color="#d7191c",
        linestyle="--",
        linewidth=1.5,
        label=f"Collapse start ({event_label})",
    )
    ax.set_title("Strait of Hormuz transit calls (PortWatch)")
    ax.set_ylabel("Daily transit calls (n_total)")
    ax.set_xlabel("Date (UTC)")
    ax.grid(alpha=0.25)
    ax.legend(loc="upper right", frameon=False)
    fig.tight_layout()
    fig.savefig(hormuz_chart_path, dpi=160)
    plt.close(fig)

    # Chart 2: recovery-duration reference classes with fitted recovery CDF.
    uncensored = [ep for ep in episodes if not ep.censored]
    severe_non_conflict: list[int] = []
    if REFERENCE_CLASS_MODE != "conflict_only":
        severe_non_conflict = [
            ep.duration_days
            for ep in uncensored
            if ep.portid not in CONFLICT_PROXY_PORTS and ep.min_ratio <= SEVERE_GLOBAL_RATIO
        ]
    conflict_affected = [
        ep.duration_days
        for ep in uncensored
        if ep.portid in CONFLICT_PROXY_PORTS
        and not is_suez_evergiven_non_conflict(ep)
        and ep.min_ratio <= SEVERE_CONFLICT_RATIO
    ]
    horizon_days = np.arange(max(known_survival_days + 1, 9), 366)
    cdf = np.array(
        [
            conditioned_recovery_cdf(float(day), known_survival_days, recovery_params)
            for day in horizon_days
        ]
    )

    # Values past the last bin edge are dropped from a histogram. Extend the
    # default 0-240 day range to cover the longest episode instead of hiding
    # the long tail, which matters most for a prolonged closure.
    longest = max(severe_non_conflict + conflict_affected, default=0)
    top_edge = max(240, 10 * math.ceil(longest / 10))
    bins = np.arange(0, top_edge + 1, 10)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.8))
    if REFERENCE_CLASS_MODE != "conflict_only":
        ax1.hist(severe_non_conflict, bins=bins, alpha=0.55, color="#4c78a8", label="Non-conflict severe")
    ax1.hist(conflict_affected, bins=bins, alpha=0.55, color="#f58518", label="Conflict-affected")
    ax1.set_title(
        "Historical disruption durations (conflict-only)"
        if REFERENCE_CLASS_MODE == "conflict_only"
        else "Historical disruption durations"
    )
    ax1.set_xlabel("Recovery duration (days)")
    ax1.set_ylabel("Episode count")
    ax1.grid(alpha=0.25)
    ax1.legend(frameon=False)

    ax2.plot(horizon_days, 100.0 * cdf, color="#54a24b", linewidth=2.4)
    ax2.set_title("Modeled P(recovered by day)")
    ax2.set_xlabel(f"Days since {event_label}")
    ax2.set_ylabel("Probability (%)")
    ax2.set_ylim(0, 100)
    ax2.grid(alpha=0.25)

    fig.tight_layout()
    fig.savefig(recovery_fit_path, dpi=160)
    plt.close(fig)


def nb_pmf(x: int, mean: float, size: float) -> float:
    """Return P(X = x) for a negative binomial parameterized by mean and size.

    With ``p = mean / (mean + size)`` the pmf is
    ``Gamma(x + size) / (Gamma(size) * x!) * (1 - p)**size * p**x``.
    Negative ``x`` has probability 0. A non-positive ``mean`` is treated as a
    point mass at 0.
    """
    if x < 0:
        return 0.0
    if mean <= 0:
        return 1.0 if x == 0 else 0.0

    # Work in log space. log(size / total) is used instead of log(1 - p): it
    # avoids cancellation in 1 - p, and it avoids log(0.0) (a ValueError) when
    # mean >> size makes p round to exactly 1.0.
    total = mean + size
    log_pmf = (
        math.lgamma(x + size)
        - math.lgamma(size)
        - math.lgamma(x + 1)
        + size * math.log(size / total)
        + x * math.log(mean / total)
    )
    return math.exp(log_pmf)


def nb_tail_geq(threshold: int, mean: float, size: float) -> float:
    """Return P(X >= threshold) for the negative binomial defined by ``nb_pmf``.

    Computed as one minus the summed pmf over ``0 .. threshold - 1`` and
    clamped to ``[0, 1]``. Non-positive thresholds return 1.0.
    """
    if threshold <= 0:
        return 1.0
    below = sum(nb_pmf(k, mean, size) for k in range(threshold))
    return max(0.0, min(1.0, 1.0 - below))


def day_end_iso(d: date) -> str:
    """Return the last second of UTC day ``d`` as an ISO-8601 string with a ``Z`` suffix."""
    return d.isoformat() + "T23:59:59Z"


def first_of_month(d: date) -> date:
    """Return the first day of the month containing ``d``."""
    return d - timedelta(days=d.day - 1)


def next_month(d: date) -> date:
    """Return the first day of the month after the one containing ``d``."""
    year_carry, zero_based_month = divmod(d.month, 12)
    return d.replace(year=d.year + year_carry, month=zero_based_month + 1, day=1)


def end_of_month(d: date) -> date:
    """Return the last day of the month containing ``d``."""
    start_of_following = next_month(first_of_month(d))
    return start_of_following - timedelta(days=1)


def infer_context_links(example: dict[str, Any]) -> list[tuple[str, str]]:
    """Return up to two conservative (label, url) context links for a historical episode.

    Links are chosen only from the episode's ``portid`` and its
    ``start_date_utc``, which must parse with ``date.fromisoformat``. A missing
    or unparseable start date yields no links. The links are de-duplicated by
    URL, keeping the first label seen for each URL, and at most two are
    returned. They are context only, not a causal attribution.
    """
    port = str(example.get("portid") or "").strip()
    start_text = str(example.get("start_date_utc") or "").strip()
    try:
        started = date.fromisoformat(start_text)
    except ValueError:
        return []

    def between(first: date, last: date) -> bool:
        # Inclusive on both ends.
        return first <= started <= last

    candidates: list[tuple[str, str]] = []

    # chokepoint13: Japan winter-snow coverage, else the pandemic-era working paper.
    if port == "chokepoint13":
        if between(date(2022, 11, 1), date(2023, 5, 31)):
            candidates.append(("AP: Heavy snow disrupted flights and rail in Japan", AP_JAPAN_SNOW_URL))
        elif between(date(2025, 11, 1), date(2026, 3, 31)):
            candidates.append(("AP: Heavy snow in Hokkaido disrupted transport", AP_JAPAN_HOKKAIDO_SNOW_2025_URL))
        elif between(date(2020, 1, 1), date(2021, 12, 31)):
            candidates.append(("IMF WP: high-frequency shipping disruptions during the pandemic", IMF_COVID_SHIPPING_WP_URL))

    if port in {"chokepoint12", "chokepoint14", "chokepoint27"} and between(date(2020, 1, 1), date(2022, 6, 30)):
        candidates.append(("UNCTAD: COVID-19 cut maritime trade in 2020", UNCTAD_COVID_MARITIME_2020_URL))

    if port == "chokepoint28" and started >= date(2022, 2, 24):
        candidates.append(("AP: Ukraine war-related attacks on Crimea bridge", AP_KERCH_TRAFFIC_URL))

    if port in {"chokepoint4", "chokepoint7"} and started >= date(2023, 10, 1):
        candidates.extend(
            [
                ("AP: Red Sea attacks forced rerouting around Africa", AP_RED_SEA_REROUTE_URL),
                ("AP: UN says Red Sea attacks reshaped global trade routes", AP_RED_SEA_UN_TRADE_URL),
            ]
        )

    if port == "chokepoint1" and between(date(2021, 3, 20), date(2021, 4, 5)):
        candidates.append(("AP: Ever Given blockage and legal aftermath", AP_SUEZ_EVERGIVEN_2023_URL))

    if port == "chokepoint3" and between(date(2022, 12, 1), date(2023, 1, 31)):
        candidates.append(("Reuters: Bosphorus tanker delays from insurance checks", REUTERS_BOSPHORUS_2022_URL))

    if port == "chokepoint11" and between(date(2024, 10, 1), date(2024, 10, 31)):
        candidates.append(("AP: China military drills around Taiwan (Oct 2024)", AP_TAIWAN_DRILLS_2024_URL))

    # Insertion-ordered dict: first occurrence of each URL wins and keeps its position.
    label_by_url: dict[str, str] = {}
    for label, url in candidates:
        label_by_url.setdefault(url, label)
    return [(label, url) for url, label in label_by_url.items()][:2]


def sample_nb(
    rng: np.random.Generator,
    mean: float,
    size: float,
    shape: tuple[int, int],
) -> np.ndarray:
    """Draw negative-binomial counts of the given ``shape`` via a gamma-Poisson mixture.

    The Poisson rates are Gamma(shape=size, scale=mean/size), so the counts
    have expectation ``mean`` and variance ``mean + mean**2 / size``. Draws
    consume ``rng`` in a fixed order: gamma first, then Poisson.
    """
    gamma_scale = mean / size
    rates = rng.gamma(size, gamma_scale, shape)
    return rng.poisson(rates)


def build_daily_rows(
    forecast_start: date,
    recovery_params: RecoveryModelParams,
    attenuation_params: AttenuationModelParams,
    known_survival_days: int,
    regime: RegimeParams,
) -> list[dict[str, str]]:
    """Build one market row per (forecast day, count threshold).

    The rows cover every UTC day from ``forecast_start`` through
    ``FORECAST_END_DATE``. For each day, P(n_total >= threshold) is a
    two-regime mixture:

    * weight ``1 - p_recovered``: the collapse-regime negative-binomial tail;
    * weight ``p_recovered``: a negative-binomial tail whose mean is
      ``regime.baseline_mean * expected_recovered_multiplier(...)``.

    ``p_recovered`` is the recovery CDF conditioned on
    ``known_survival_days``. It is evaluated at the 1-based day index since
    ``EVENT_START_DATE``, measured at the day's 23:59:59 UTC. Within a day,
    each probability is capped by the previous threshold's value. This gives
    a non-increasing sequence if ``COUNT_THRESHOLDS`` is ascending (presumably
    so; confirm). Each value is then clamped to ``[0.001, 0.999]``.
    """
    collapse_tails = {
        threshold: nb_tail_geq(threshold, regime.collapse_mean, regime.collapse_size)
        for threshold in COUNT_THRESHOLDS
    }
    memo: dict[int, tuple[float, float]] = {}
    out: list[dict[str, str]] = []
    one_day = timedelta(days=1)

    day = forecast_start
    while day <= FORECAST_END_DATE:
        day_close = datetime(day.year, day.month, day.day, 23, 59, 59, tzinfo=timezone.utc)
        event_day = (day_close - EVENT_START_DATE).days + 1

        state = memo.get(event_day)
        if state is None:
            p_rec = conditioned_recovery_cdf(
                float(event_day),
                known_survival_days=known_survival_days,
                params=recovery_params,
            )
            multiplier = expected_recovered_multiplier(
                days_since_event=event_day,
                known_survival_days=known_survival_days,
                recovery_params=recovery_params,
                attenuation_params=attenuation_params,
            )
            state = (p_rec, regime.baseline_mean * multiplier)
            memo[event_day] = state
        p_rec, rec_mean = state

        stamp = day_end_iso(day)
        ceiling = 1.0
        for threshold in COUNT_THRESHOLDS:
            rec_tail = nb_tail_geq(threshold, rec_mean, regime.baseline_size)
            mixed = (1.0 - p_rec) * collapse_tails[threshold] + p_rec * rec_tail
            # Cap by the previous threshold, then clamp away from 0 and 1.
            ceiling = min(0.999, max(0.001, min(mixed, ceiling)))
            out.append(
                {
                    "projection_group": stamp,
                    "threshold_decimal": str(threshold),
                    "threshold_date": "",
                    "initial_probability": f"{ceiling:.{PROBABILITY_DECIMALS}f}",
                    "label": f"{day.isoformat()}, >= {threshold}",
                    "end_date": stamp,
                    "decay_rate": f"{MARKET_DECAY_RATE:.3f}",
                    "status": "open",
                }
            )
        day += one_day

    return out


def start_of_week_monday(value: date) -> date:
    """Return the Monday that starts the (Monday-to-Sunday) week containing ``value``."""
    days_into_week = value.isoweekday() - 1
    return value - timedelta(days=days_into_week)


def end_of_week_sunday(value: date) -> date:
    """Return the Sunday that ends the (Monday-to-Sunday) week containing ``value``."""
    monday = start_of_week_monday(value)
    return monday + timedelta(days=6)


def next_week_monday(value: date) -> date:
    """Return the Monday of the week following the one containing ``value``."""
    monday = start_of_week_monday(value)
    return monday + timedelta(weeks=1)


def iso_week_projection_group(value: date) -> str:
    """Return the ISO-8601 week label (``YYYY-Www``) for ``value``, using the ISO week-year."""
    iso = value.isocalendar()
    return "{}-W{:02d}".format(iso[0], iso[1])


def build_weekly_sum_rows(
    forecast_start: date,
    hormuz_df: pd.DataFrame,
    recovery_params: RecoveryModelParams,
    attenuation_params: AttenuationModelParams,
    known_survival_days: int,
    regime: RegimeParams,
) -> list[dict[str, str]]:
    """Build one market row per (complete Monday-to-Sunday week, weekly-sum threshold).

    The first week is the one containing ``forecast_start``. The last week is
    the final one that ends on or before ``FORECAST_END_DATE``; a trailing
    partial week is skipped. Each weekly total is the sum of observed
    ``n_total`` values from ``hormuz_df`` plus a Monte Carlo draw
    (120,000 paths, seed 42) for every day not yet observed.

    Each simulated path draws a single recovery point. Unobserved days before
    it come from the collapse-regime negative binomial, and days on or after
    it come from the recovered-regime negative binomial. Probabilities of
    ``total >= threshold`` are capped by the previous threshold's value, which
    gives a non-increasing sequence if ``WEEKLY_SUM_THRESHOLDS`` is ascending
    (presumably so; confirm). Each value is then clamped to ``[0.001, 0.999]``.
    """
    rows: list[dict[str, str]] = []
    rng = np.random.default_rng(42)
    n_sims = 120_000
    state_cache: dict[int, tuple[float, float]] = {}

    # Observed daily counts keyed by calendar date (a later duplicate date wins).
    observed: dict[date, float] = {
        ts.date(): float(value)
        for ts, value in zip(hormuz_df["date"], hormuz_df["n_total"])
    }

    def day_state(day: date) -> tuple[float, float]:
        """Return (P(recovered by end of ``day``), recovered-regime mean) for ``day``."""
        day_close = datetime(day.year, day.month, day.day, 23, 59, 59, tzinfo=timezone.utc)
        days_since_event = (day_close - EVENT_START_DATE).days + 1
        if days_since_event not in state_cache:
            p_recovered = conditioned_recovery_cdf(
                float(days_since_event),
                known_survival_days=known_survival_days,
                params=recovery_params,
            )
            recovered_mult = expected_recovered_multiplier(
                days_since_event=days_since_event,
                known_survival_days=known_survival_days,
                recovery_params=recovery_params,
                attenuation_params=attenuation_params,
            )
            state_cache[days_since_event] = (p_recovered, regime.baseline_mean * recovered_mult)
        return state_cache[days_since_event]

    week_cursor = start_of_week_monday(forecast_start)
    while week_cursor <= FORECAST_END_DATE:
        week_end = end_of_week_sunday(week_cursor)
        if week_end > FORECAST_END_DATE:
            break
        n_days = (week_end - week_cursor).days + 1
        week_dates = [week_cursor + timedelta(days=i) for i in range(n_days)]

        observed_sum = sum((observed[d] for d in week_dates if d in observed), 0.0)
        unknown_dates = [d for d in week_dates if d not in observed]

        sim_sum = np.zeros(n_sims, dtype=float)
        if unknown_dates:
            states = [day_state(d) for d in unknown_dates]
            p_arr = np.array([p for p, _ in states], dtype=float)
            recovered_means = np.array([m for _, m in states], dtype=float)
            shape = (n_sims, len(unknown_dates))
            # Recovery is a single event, not an independent daily coin flip.
            # Each path gets one uniform draw and is marked recovered on day i
            # when u < P(recovered by day i). That probability is non-decreasing
            # in i, so every path stays collapsed up to its recovery day and
            # recovered afterwards, matching the duration model behind
            # conditioned_recovery_cdf. Independent per-day draws would let a
            # path flip between regimes and understate the spread of the total.
            rec_mask = rng.random((n_sims, 1)) < p_arr
            low = sample_nb(
                rng=rng,
                mean=regime.collapse_mean,
                size=regime.collapse_size,
                shape=shape,
            )
            lam_high = rng.gamma(
                shape=regime.baseline_size,
                scale=(recovered_means / regime.baseline_size),
                size=shape,
            )
            high = rng.poisson(lam_high)
            sim_sum = np.where(rec_mask, high, low).sum(axis=1, dtype=float)

        week_total = observed_sum + sim_sum
        projection_group = iso_week_projection_group(week_end)
        end_date_iso = day_end_iso(week_end)

        prev_prob = 1.0
        for threshold in WEEKLY_SUM_THRESHOLDS:
            prob = float(np.mean(week_total >= float(threshold)))
            prob = min(prob, prev_prob)
            prob = min(0.999, max(0.001, prob))
            prev_prob = prob
            rows.append(
                {
                    "projection_group": projection_group,
                    "threshold_decimal": str(threshold),
                    "threshold_date": "",
                    "initial_probability": f"{prob:.{PROBABILITY_DECIMALS}f}",
                    "label": f"Week ending {week_end.isoformat()}, total calls >= {threshold}",
                    "end_date": end_date_iso,
                    "decay_rate": f"{MARKET_DECAY_RATE:.3f}",
                    "status": "open",
                }
            )

        week_cursor = next_week_monday(week_cursor)

    return rows


def build_monthly_average_rows(
    forecast_start: date,
    hormuz_df: pd.DataFrame,
    recovery_params: RecoveryModelParams,
    attenuation_params: AttenuationModelParams,
    known_survival_days: int,
    regime: RegimeParams,
) -> list[dict[str, str]]:
    """Build one market row per (complete calendar month, average-daily-calls threshold).

    The first month is the one containing ``forecast_start``. The last month
    is the final one that ends on or before ``FORECAST_END_DATE``. A month cut
    short by the horizon is skipped rather than averaged over only part of its
    days, because the resolution criteria average over every UTC day of the
    month. The weekly builder likewise skips a partial final week.

    Each monthly average combines observed ``n_total`` values from
    ``hormuz_df`` with a Monte Carlo draw (120,000 paths, seed 42) for every
    day not yet observed. Each simulated path draws a single recovery point.
    Unobserved days before it come from the collapse-regime negative binomial,
    and days on or after it come from the recovered-regime negative binomial.
    Probabilities of ``average >= threshold`` are capped by the previous
    threshold's value, which gives a non-increasing sequence if
    ``COUNT_THRESHOLDS`` is ascending (presumably so; confirm). Each value is
    then clamped to ``[0.001, 0.999]``.
    """
    rows: list[dict[str, str]] = []
    rng = np.random.default_rng(42)
    n_sims = 120_000
    state_cache: dict[int, tuple[float, float]] = {}

    # Observed daily counts keyed by calendar date (a later duplicate date wins).
    observed: dict[date, float] = {
        ts.date(): float(value)
        for ts, value in zip(hormuz_df["date"], hormuz_df["n_total"])
    }

    def day_state(day: date) -> tuple[float, float]:
        """Return (P(recovered by end of ``day``), recovered-regime mean) for ``day``."""
        day_close = datetime(day.year, day.month, day.day, 23, 59, 59, tzinfo=timezone.utc)
        days_since_event = (day_close - EVENT_START_DATE).days + 1
        if days_since_event not in state_cache:
            p_recovered = conditioned_recovery_cdf(
                float(days_since_event),
                known_survival_days=known_survival_days,
                params=recovery_params,
            )
            recovered_mult = expected_recovered_multiplier(
                days_since_event=days_since_event,
                known_survival_days=known_survival_days,
                recovery_params=recovery_params,
                attenuation_params=attenuation_params,
            )
            state_cache[days_since_event] = (p_recovered, regime.baseline_mean * recovered_mult)
        return state_cache[days_since_event]

    month_cursor = first_of_month(forecast_start)
    while month_cursor <= FORECAST_END_DATE:
        month_end = end_of_month(month_cursor)
        if month_end > FORECAST_END_DATE:
            # Averaging a truncated month would not match the "all UTC days in
            # that month" resolution rule, so stop at the last complete month.
            break
        n_days = (month_end - month_cursor).days + 1
        month_dates = [month_cursor + timedelta(days=i) for i in range(n_days)]

        observed_sum = sum((observed[d] for d in month_dates if d in observed), 0.0)
        unknown_dates = [d for d in month_dates if d not in observed]

        sim_sum = np.zeros(n_sims, dtype=float)
        if unknown_dates:
            states = [day_state(d) for d in unknown_dates]
            p_arr = np.array([p for p, _ in states], dtype=float)
            recovered_means = np.array([m for _, m in states], dtype=float)
            shape = (n_sims, len(unknown_dates))
            # Recovery is a single event, not an independent daily coin flip.
            # Each path gets one uniform draw and is marked recovered on day i
            # when u < P(recovered by day i). That probability is non-decreasing
            # in i, so every path stays collapsed up to its recovery day and
            # recovered afterwards, matching the duration model behind
            # conditioned_recovery_cdf. Independent per-day draws would let a
            # path flip between regimes and understate the spread of the average.
            rec_mask = rng.random((n_sims, 1)) < p_arr
            low = sample_nb(
                rng=rng,
                mean=regime.collapse_mean,
                size=regime.collapse_size,
                shape=shape,
            )
            lam_high = rng.gamma(
                shape=regime.baseline_size,
                scale=(recovered_means / regime.baseline_size),
                size=shape,
            )
            high = rng.poisson(lam_high)
            sim_sum = np.where(rec_mask, high, low).sum(axis=1, dtype=float)

        month_avg = (observed_sum + sim_sum) / float(len(month_dates))
        projection_group = month_cursor.strftime("%Y-%m")
        end_date_iso = day_end_iso(month_end)

        prev_prob = 1.0
        for threshold in COUNT_THRESHOLDS:
            prob = float(np.mean(month_avg >= float(threshold)))
            prob = min(prob, prev_prob)
            prob = min(0.999, max(0.001, prob))
            prev_prob = prob
            rows.append(
                {
                    "projection_group": projection_group,
                    "threshold_decimal": str(threshold),
                    "threshold_date": "",
                    "initial_probability": f"{prob:.{PROBABILITY_DECIMALS}f}",
                    "label": f"{projection_group} average daily calls >= {threshold}",
                    "end_date": end_date_iso,
                    "decay_rate": f"{MARKET_DECAY_RATE:.3f}",
                    "status": "open",
                }
            )

        month_cursor = next_month(month_cursor)

    return rows


def build_resolution_criteria(mode: str) -> str:
    """Return the HTML resolution-criteria text for a market of the given ``mode``.

    ``"monthly_average"`` and ``"weekly_sum"`` get their aggregation rules.
    Any other mode gets the per-day rule, used by the daily market rows. In
    every mode, the resolving series is ``n_total`` for ``chokepoint6`` from
    the PortWatch daily chokepoint dataset.
    """
    api_info_url = "https://www.arcgis.com/sharing/rest/content/items/42132aa4e2fc4d41bdaf9a445f688931?f=json"
    portwatch_page_url = PORTWATCH_PAGE_URL

    if mode == "monthly_average":
        return (
            "<p>This market resolves using the <code>n_total</code> variable (daily transit calls) for "
            "<code>portid='chokepoint6'</code> (Strait of Hormuz) from the PortWatch "
            "'Daily Chokepoint Transit Calls and Trade Volume Estimates' "
            f"<a href=\"{api_info_url}\">API</a>.</p>"
            "<p>For each month (projection_group = YYYY-MM), the resolved value is the arithmetic mean of "
            "<code>n_total</code> across all UTC days in that month. Each threshold row resolves YES if the monthly "
            "average is greater than or equal to <code>threshold_decimal</code>, otherwise NO.</p>"
            f"<p>Reference page: <a href=\"{portwatch_page_url}\">IMF PortWatch chokepoint page (Strait of Hormuz)</a>.</p>"
            "<p>If historical values are revised, resolution values will be updated accordingly.</p>"
        )

    if mode == "weekly_sum":
        return (
            "<p>This market resolves using the <code>n_total</code> variable (daily transit calls) for "
            "<code>portid='chokepoint6'</code> (Strait of Hormuz) from the PortWatch "
            "'Daily Chokepoint Transit Calls and Trade Volume Estimates' "
            f"<a href=\"{api_info_url}\">API</a>.</p>"
            "<p>Each projection group corresponds to one ISO week in UTC, ending on Sunday "
            "(projection_group = YYYY-Www). The resolved value for each weekly group is the sum of "
            "<code>n_total</code> across all UTC days from Monday through Sunday of that ISO week. "
            "Each threshold row resolves YES if the weekly total is greater than or equal to "
            "<code>threshold_decimal</code>, otherwise NO.</p>"
            f"<p>Reference page: <a href=\"{portwatch_page_url}\">IMF PortWatch chokepoint page (Strait of Hormuz)</a>.</p>"
            "<p>The linked PortWatch chart includes tanker traffic, but its hover readout does not show numeric "
            "values for every vessel category at once. This market resolves on overall <code>n_total</code> "
            "transit calls, not any single category-specific hover value.</p>"
            "<p>If historical values are revised, resolution values will be updated accordingly.</p>"
        )

    # Daily market: one projection group per UTC day. State the per-row rule
    # explicitly, as the weekly and monthly variants do.
    return (
        "<p>This market resolves using the <code>n_total</code> variable (total daily transit calls) for "
        "<code>portid='chokepoint6'</code> (Strait of Hormuz) from the PortWatch "
        "'Daily Chokepoint Transit Calls and Trade Volume Estimates' "
        f"<a href=\"{api_info_url}\">API</a>.</p>"
        "<p>Each projection group corresponds to one UTC day. Each threshold row resolves YES if "
        "<code>n_total</code> for that UTC day is greater than or equal to "
        "<code>threshold_decimal</code>, otherwise NO.</p>"
        f"<p>Reference page: <a href=\"{portwatch_page_url}\">IMF PortWatch chokepoint page (Strait of Hormuz)</a>.</p>"
        "<p>If historical values are revised, resolution values will be updated accordingly.</p>"
    )


def build_background_html(
    mode: str,
    market_title: str,
    hormuz_df: pd.DataFrame,
    known_survival_days: int,
    recovery_params: RecoveryModelParams,
    attenuation_params: AttenuationModelParams,
    regime: RegimeParams,
    diagnostics: dict[str, Any],
    forecast_start: date,
) -> str:
    latest_date = hormuz_df["date"].max().to_pydatetime()

    pre = int(hormuz_df[hormuz_df["date"] == datetime(2026, 2, 28, tzinfo=timezone.utc)]["n_total"].iloc[0])
    day1 = int(hormuz_df[hormuz_df["date"] == datetime(2026, 3, 1, tzinfo=timezone.utc)]["n_total"].iloc[0])
    week_mean = float(
        hormuz_df[
            (hormuz_df["date"] >= COLLAPSE_START) & (hormuz_df["date"] <= COLLAPSE_END)
        ]["n_total"].mean()
    )

    def pct(x: float) -> str:
        return f"{100.0 * x:.1f}%"

    days_for_preview = [30, 60, 90, 120, 180]
    rec_preview = [
        conditioned_recovery_cdf(float(d), known_survival_days, recovery_params)
        for d in days_for_preview
    ]
    recovered_mult_preview = [
        expected_recovered_multiplier(
            days_since_event=d,
            known_survival_days=known_survival_days,
            recovery_params=recovery_params,
            attenuation_params=attenuation_params,
        )
        for d in days_for_preview
    ]

    conflict_examples = diagnostics.get("conflict_proxy_examples_top_duration", [])
    global_severe_examples = diagnostics.get("global_severe_examples_top_duration", [])
    global_stress_examples = diagnostics.get("global_stress_examples_top_duration", [])
    red_sea_examples = diagnostics.get("red_sea_route_examples_top_duration", [])
    analogue_coverage = diagnostics.get("analogue_coverage", {})
    external_analogues = diagnostics.get("external_analogues", [])
    attenuation_diag = diagnostics.get("post_recovery_attenuation", {})
    att30 = attenuation_diag.get("ratio_horizon_30d", {})
    att60 = attenuation_diag.get("ratio_horizon_60d", {})
    att90 = attenuation_diag.get("ratio_horizon_90d", {})
    conflict_only_mode = diagnostics.get("reference_class_mode") == "conflict_only"

    def render_example_rows(items: list[dict[str, Any]]) -> str:
        if not items:
            return "<tr><td colspan='6'>No examples available</td></tr>"

        rendered: list[str] = []
        for item in items:
            contexts = infer_context_links(item)
            if contexts:
                context_text = "<br/>".join(
                    f"<a href=\"{html.escape(url, quote=True)}\">{html.escape(label)}</a>"
                    for label, url in contexts
                )
            else:
                context_text = "N/A"
            rendered.append(
                "<tr>"
                f"<td>{html.escape(str(item['portname']))}</td>"
                f"<td>{html.escape(str(item['start_date_utc']))}</td>"
                f"<td>{html.escape(str(item.get('end_date_utc') or 'ongoing/censored'))}</td>"
                f"<td>{item['duration_days']}</td>"
                f"<td>{100.0 * float(item['min_ratio']):.1f}%</td>"
                f"<td>{context_text}</td>"
                "</tr>"
            )
        return "".join(rendered)

    stress_summary = None if conflict_only_mode else diagnostics.get("stress_non_conflict_duration_days")
    stress_class_li = (
        f"<li>Broader non-conflict stress episodes: {diagnostics['stress_global_non_conflict_count']} episodes "
        "with traffic between 20% and 35% of baseline.</li>"
        if stress_summary
        else ""
    )
    stress_duration_li = (
        f"<li>Broader non-conflict stress duration summary: median {stress_summary['median']:.1f} days, "
        f"90th percentile {stress_summary['p90']:.1f} days.</li>"
        if stress_summary
        else ""
    )
    suez_override_li = (
        f"<li>Non-conflict override: {diagnostics.get('suez_non_conflict_override_count', 0)} Suez blockage episode(s) "
        "from the 2021 Ever Given disruption were added to the non-conflict stress class.</li>"
        if (not conflict_only_mode and diagnostics.get("suez_non_conflict_override_count", 0) > 0)
        else ""
    )
    reference_duration_li = (
        f"<li>Conflict reference-class duration summary: median {diagnostics['severe_reference_duration_days']['median']:.1f} days, "
        f"90th percentile {diagnostics['severe_reference_duration_days']['p90']:.1f} days.</li>"
        if conflict_only_mode
        else f"<li>Severe class duration summary: median {diagnostics['severe_reference_duration_days']['median']:.1f} days, 90th percentile {diagnostics['severe_reference_duration_days']['p90']:.1f} days.</li>"
    )
    conflict_duration_li = (
        ""
        if conflict_only_mode
        else f"<li>Conflict-affected duration summary: median {diagnostics['conflict_affected_duration_days']['median']:.1f} days, 90th percentile {diagnostics['conflict_affected_duration_days']['p90']:.1f} days.</li>"
    )
    stress_table_html = (
        f"""
<h3>Key broader non-conflict stress examples</h3>
<table>
  <thead>
    <tr><th>Bottleneck</th><th>Start</th><th>End</th><th>Duration (days)</th><th>Lowest traffic vs pre-shock baseline</th><th>Possible context (news/reports)</th></tr>
  </thead>
  <tbody>
    {render_example_rows(global_stress_examples)}
  </tbody>
</table>
"""
        if stress_summary
        else ""
    )
    red_sea_table_html = (
        f"""
<h3>Key Red Sea route disruption examples (Suez and Bab el-Mandeb)</h3>
<table>
  <thead>
    <tr><th>Bottleneck</th><th>Start</th><th>End</th><th>Duration (days)</th><th>Lowest traffic vs pre-shock baseline</th><th>Possible context (news/reports)</th></tr>
  </thead>
  <tbody>
    {render_example_rows(red_sea_examples)}
  </tbody>
</table>
"""
        if red_sea_examples
        else ""
    )
    additional_analogues_table_html = f"""
<h3>{"Additional conflict analogues considered" if conflict_only_mode else "Additional analogues considered"}</h3>
<table>
  <thead>
    <tr><th>Analogue</th><th>Representative duration (days)</th><th>Included in model?</th><th>How included</th><th>Source</th></tr>
  </thead>
  <tbody>
    {"".join(
        (
            "<tr>"
            f"<td>{html.escape(str(item.get('name', 'N/A')))}</td>"
            f"<td>{int(item.get('duration_days', 0))}</td>"
            f"<td>{'Yes' if analogue_coverage.get(str(item.get('key')), {}).get('included_in_model') else 'No'}</td>"
            f"<td>{html.escape(str(analogue_coverage.get(str(item.get('key')), {}).get('mode', 'external_duration_prior')))}</td>"
            f"<td><a href=\"{html.escape(str(item.get('source_url', '')), quote=True)}\">{html.escape(str(item.get('source_label', 'source')))}</a></td>"
            "</tr>"
        )
        for item in external_analogues
    )}
  </tbody>
</table>
"""

    severe_class_label = (
        "Conflict-affected disruptions used for fitting"
        if conflict_only_mode
        else (
            "Non-conflict severe disruptions"
            if diagnostics.get("severe_reference_non_conflict_only", False)
            else "Severe disruptions (fallback class)"
        )
    )
    overlap_note = (
        ""
        if conflict_only_mode
        else (
            f"{diagnostics.get('overlap_removed_count', 0)} conflict-affected severe episodes were excluded from the global severe class "
            "to avoid double counting."
        )
    )
    class_paragraph = (
        "This run intentionally uses a conflict-only reference class for disruption duration and post-recovery behavior. "
        "That keeps starting probabilities anchored to geopolitically similar episodes, rather than blending with broader non-conflict disruptions."
        if conflict_only_mode
        else (
            "Why multiple classes? The severe global class gives a broad base rate. The conflict-affected class captures "
            "bottlenecks that are geopolitically more similar to Hormuz. The broader stress class (when available) widens "
            "coverage beyond only the most extreme disruptions."
        )
    )
    context_paragraph = (
        "Most high-weight examples in this run are conflict-linked disruptions in Suez, Bab el-Mandeb, Bosporus, and Kerch, "
        "plus longer conflict analogues such as the Tanker War and the post-2023 Red Sea attacks."
        if conflict_only_mode
        else (
            "What happened in many of these examples? A substantial share of 2020 to 2022 entries coincide with pandemic-era "
            "shipping disruption. Several Kerch Strait disruptions from 2022 onward overlap with the Russia-Ukraine war period. "
            "Tsugaru Strait examples are often winter episodes, and published sailing directions "
            f"(<a href=\"{JCG_TSUGARU_NAV_GUIDE_URL}\">NGA Pub 120</a>) note strong currents, dense fog, and seasonal snowstorms "
            "in this route. To avoid over-linking one generic document, the table uses date-specific news/report links where "
            "available and leaves rows as N/A when no clear period-specific source is identified. These links are context only "
            "and are not treated as definitive causal attribution in the model."
        )
    )
    effective_sample_text = (
        "Effective sample size in the conflict class was "
        f"{diagnostics['effective_sample_sizes']['conflict_affected']:.1f}. "
        "External conflict analogues use overlap discounting: if partly represented in PortWatch, "
        f"their weight is reduced to {100.0 * diagnostics.get('external_overlap_discount', EXTERNAL_OVERLAP_DISCOUNT):.0f}%."
        if conflict_only_mode
        else (
            "Effective sample sizes used in fitting were: severe class "
            f"{diagnostics['effective_sample_sizes']['severe_reference']:.1f}, conflict-affected "
            f"{diagnostics['effective_sample_sizes']['conflict_affected']:.1f}, and broader stress "
            f"{diagnostics['effective_sample_sizes']['stress_non_conflict']:.1f}. External analogues use an overlap discount: "
            "when an analogue is already partly represented in PortWatch, it is down-weighted to "
            f"{100.0 * diagnostics.get('external_overlap_discount', EXTERNAL_OVERLAP_DISCOUNT):.0f}% in the external prior."
        )
    )
    market_target_paragraph = (
        "For this weekly market, the model forecasts daily <code>n_total</code> first and then sums those daily calls across each ISO week (Monday-Sunday UTC) to produce the traded weekly totals."
        if mode == "weekly_sum"
        else (
            "For this market, the model forecasts daily <code>n_total</code> first and then averages those daily calls across each ISO week (Monday-Sunday UTC) to produce the traded weekly values."
            if mode == "weekly_average"
            else ""
        )
    )
    attenuation_pooling_text = (
        "Because this run is conflict-only, the attenuation fit uses conflict episode ratios directly (with shrinkage toward their pooled conflict mean)."
        if conflict_only_mode
        else "These are pooled so the conflict estimate is informed by broader history but not dominated by it."
    )
    weight_labels: list[str] = []
    if recovery_params.weight_hormuz > 0.001:
        weight_labels.append(f"<strong>{100.0 * recovery_params.weight_hormuz:.1f}%</strong> current-Hormuz stress assumption")
    if recovery_params.weight_conflict > 0.001:
        weight_labels.append(f"<strong>{100.0 * recovery_params.weight_conflict:.1f}%</strong> conflict-affected history")
    if recovery_params.weight_global_severe > 0.001:
        weight_labels.append(f"<strong>{100.0 * recovery_params.weight_global_severe:.1f}%</strong> severe global history")
    if recovery_params.weight_global_stress > 0.001:
        weight_labels.append(f"<strong>{100.0 * recovery_params.weight_global_stress:.1f}%</strong> broader stress history")
    if recovery_params.weight_external > 0.001:
        weight_labels.append(f"<strong>{100.0 * recovery_params.weight_external:.1f}%</strong> external-duration prior")
    weight_text = ",\n".join(weight_labels) + "."
    attenuation_ratio_text = (
        "From the conflict-affected reference class, the average post-recovery ratio "
        "(<code>n_total / pre-shock baseline</code>) had medians of "
        f"{100.0 * float(att30.get('conflict_median', 1.0)):.1f}% at 30 days, "
        f"{100.0 * float(att60.get('conflict_median', 1.0)):.1f}% at 60 days, and "
        f"{100.0 * float(att90.get('conflict_median', 1.0)):.1f}% at 90 days."
        if conflict_only_mode
        else (
            "From the fitted reference classes, the average post-recovery ratio "
            "(<code>n_total / pre-shock baseline</code>) was: overall median "
            f"{100.0 * float(att30.get('global_median', 1.0)):.1f}% at 30 days, "
            f"{100.0 * float(att60.get('global_median', 1.0)):.1f}% at 60 days, and "
            f"{100.0 * float(att90.get('global_median', 1.0)):.1f}% at 90 days. "
            "For conflict-affected episodes, medians were "
            f"{100.0 * float(att30.get('conflict_median', 1.0)):.1f}% (30d), "
            f"{100.0 * float(att60.get('conflict_median', 1.0)):.1f}% (60d), and "
            f"{100.0 * float(att90.get('conflict_median', 1.0)):.1f}% (90d)."
        )
    )
    global_severe_table_html = (
        f"""
<h3>Key global severe examples (longest disruptions)</h3>
<table>
  <thead>
    <tr><th>Bottleneck</th><th>Start</th><th>End</th><th>Duration (days)</th><th>Lowest traffic vs pre-shock baseline</th><th>Possible context (news/reports)</th></tr>
  </thead>
  <tbody>
    {render_example_rows(global_severe_examples)}
  </tbody>
</table>
"""
        if (not conflict_only_mode and global_severe_examples)
        else ""
    )

    return f"""<h2>How Starting Probabilities Were Calculated</h2>
<p>
The starting probabilities of this market were calculated by combining the recent Strait of Hormuz traffic data
with historical chokepoint disruption episodes from PortWatch. The model uses reference classes and partial pooling
to keep the estimate data-driven while avoiding overfitting to any single bottleneck.
</p>
{"<p>" + market_target_paragraph + "</p>" if market_target_paragraph else ""}

<h2>Recent traffic collapse in Hormuz</h2>
<p>
PortWatch shows a sharp break in traffic: {pre} calls on <strong>2026-02-28</strong>, then {day1} calls on
<strong>2026-03-01</strong>. The average from {COLLAPSE_START.date().isoformat()} to
{COLLAPSE_END.date().isoformat()} was {week_mean:.1f} calls/day.
</p>
<p>
The linked IMF PortWatch page is useful context, and it does include tanker traffic, but the hover UI does not
show numeric values for every vessel category at once. This market uses the full <code>n_total</code> series
for Strait of Hormuz transit calls, so it is not resolving on a tanker-only subset or a category-specific
hover figure.
</p>

<h2>Reference classes used</h2>
<ul>
  <li>{severe_class_label}: {diagnostics['severe_reference_count']} episodes.</li>
  <li>Conflict-affected chokepoints: {diagnostics['conflict_affected_count']} episodes in Suez, Bosporus, Bab el-Mandeb, and Kerch where traffic fell to 35% or less of baseline.</li>
  {stress_class_li}
  <li>Current Hormuz event: treated as ongoing evidence, with duration at least {known_survival_days} days as of {latest_date.date().isoformat()}.</li>
  {reference_duration_li}
  {conflict_duration_li}
  {stress_duration_li}
  {suez_override_li}
  <li>External-duration prior (non-PortWatch analogues): {len(external_analogues)} analogues with representative durations, added with low weight.</li>
  <li>Model-fit exclusions: {diagnostics.get('excluded_from_fit_count', 0)} episodes from diversion-heavy routes (currently Cape of Good Hope) were excluded from fitted classes to avoid treating rerouting volatility as shutdown risk.</li>
</ul>

<p>
{class_paragraph}
</p>

<p>
{context_paragraph}
</p>

<p>
To avoid overconfidence from repeated episodes at the same chokepoint, we use an effective-sample-size adjustment:
additional episodes from the same chokepoint count as {100.0 * EFFECTIVE_SAMPLE_FRACTION_WITHIN_PORT:.0f}% of a fully independent observation.
{effective_sample_text}
</p>

<h2>Post-recovery attenuation (below 100% after reopening)</h2>
<p>
Historical episodes often reopen before traffic is fully back to the pre-shock baseline. We therefore include a
post-recovery attenuation factor, so "recovered" traffic can stay somewhat below 100% for a while instead of
jumping immediately to full baseline levels.
</p>
<p>
{attenuation_ratio_text}
</p>
<p>
{attenuation_pooling_text} The Hormuz
attenuation factor used by the model is {100.0 * attenuation_params.ratio30_hormuz:.1f}% around 30 days after recovery,
then gradually decays back toward 100% with an estimated half-life of {attenuation_params.half_life_days:.1f} days.
</p>

<p>
Final blend weights in the model:
{weight_text}
</p>

{"<p>" + overlap_note + "</p>" if overlap_note else ""}

<h3>Key conflict-affected examples (longest disruptions)</h3>
<table>
  <thead>
    <tr><th>Bottleneck</th><th>Start</th><th>End</th><th>Duration (days)</th><th>Lowest traffic vs pre-shock baseline</th><th>Possible context (news/reports)</th></tr>
  </thead>
  <tbody>
    {render_example_rows(conflict_examples)}
  </tbody>
</table>

{global_severe_table_html}

{stress_table_html}
{red_sea_table_html}
{additional_analogues_table_html}

<h2>How daily call counts were modeled</h2>
<ul>
  <li>Low-traffic phase (from recent Hormuz collapse): mean {regime.collapse_mean:.2f}, variance {regime.collapse_var:.2f}.</li>
  <li>Recovered phase (from pre-collapse Hormuz period): base mean {regime.baseline_mean:.2f}, variance {regime.baseline_var:.2f}.</li>
  <li>Post-recovery attenuation: recovered mean is multiplied by a factor that starts below 1 and relaxes back toward 1 over time.</li>
</ul>
<p>
For each future day, the model combines those two distributions using the estimated probability that traffic has
recovered by that day, plus the attenuation factor to reflect slower normalization after reopening.
</p>

<h2>Fitting charts</h2>
<p><img src="{{{{asset:hormuz_calls_recent.png}}}}" alt="Recent Strait of Hormuz daily transit calls" /></p>
<p><img src="{{{{asset:recovery_duration_fit.png}}}}" alt="Reference-class recovery duration fit" /></p>

<h2>Estimated chance that traffic has recovered by day N</h2>
<table>
  <thead><tr><th>Days since 2026-03-01</th><th>P(recovered by day)</th></tr></thead>
  <tbody>
    <tr><td>{days_for_preview[0]}</td><td>{pct(rec_preview[0])}</td></tr>
    <tr><td>{days_for_preview[1]}</td><td>{pct(rec_preview[1])}</td></tr>
    <tr><td>{days_for_preview[2]}</td><td>{pct(rec_preview[2])}</td></tr>
    <tr><td>{days_for_preview[3]}</td><td>{pct(rec_preview[3])}</td></tr>
    <tr><td>{days_for_preview[4]}</td><td>{pct(rec_preview[4])}</td></tr>
  </tbody>
</table>

<h2>Estimated recovered-state traffic level by day N</h2>
<table>
  <thead><tr><th>Days since 2026-03-01</th><th>Recovered-state multiplier</th></tr></thead>
  <tbody>
    <tr><td>{days_for_preview[0]}</td><td>{pct(recovered_mult_preview[0])}</td></tr>
    <tr><td>{days_for_preview[1]}</td><td>{pct(recovered_mult_preview[1])}</td></tr>
    <tr><td>{days_for_preview[2]}</td><td>{pct(recovered_mult_preview[2])}</td></tr>
    <tr><td>{days_for_preview[3]}</td><td>{pct(recovered_mult_preview[3])}</td></tr>
    <tr><td>{days_for_preview[4]}</td><td>{pct(recovered_mult_preview[4])}</td></tr>
  </tbody>
</table>

<h2>Data sources</h2>
<ul>
  <li><a href="{PORTWATCH_PAGE_URL}">IMF PortWatch chokepoint page (Strait of Hormuz)</a></li>
  <li><a href="https://www.arcgis.com/sharing/rest/content/items/42132aa4e2fc4d41bdaf9a445f688931?f=json">PortWatch dataset metadata</a></li>
  <li><a href="{PORTWATCH_QUERY_URL}?where=portid%3D%27chokepoint6%27&amp;outFields=date%2Cn_total&amp;orderByFields=date&amp;f=json">PortWatch API query endpoint for Hormuz n_total</a></li>
</ul>

<h2>Methodology files</h2>
<p>
  <a href="{{{{asset:generate_market.py}}}}">Download the Python methodology script</a><br/>
  <a href="{{{{asset:model_summary.json}}}}">Download the model summary JSON</a><br/>
  Data snapshot used: <strong>{latest_date.date().isoformat()} (UTC)</strong>.
</p>
"""


def _single_line_meta(value: object) -> str:
    """Return ``str(value)`` with its lines joined by single spaces.

    Metadata is emitted as ``# key: value`` comment lines. An embedded line break would
    push the remainder of the value onto a line without the ``#`` prefix, where a CSV
    reader would treat it as data. Values without line breaks are returned unchanged.
    """
    return " ".join(str(value).splitlines())


def write_market_csv(
    output_path: Path,
    market_code: str,
    market_title: str,
    mode: str,
    background_filename: str,
    rows: list[dict[str, str]],
    resolution_criteria: str,
    market_end_date: str,
    market_resolution_date: str,
) -> None:
    """Write the market CSV: a ``# key: value`` metadata header followed by threshold rows.

    Args:
        output_path: Destination file; missing parent directories are created.
        market_code: Value of the ``market_code`` metadata line.
        market_title: Value of the ``market_title`` metadata line.
        mode: Market structure (``monthly_average``, ``daily_submarkets`` or
            ``weekly_sum``). Selects the x-axis unit, the svelte ``timeCadence`` and
            which threshold list is counted in ``generated_note``.
        background_filename: Recorded as ``market_background_info_path``.
        rows: Row dicts written under the header. Keys must be a subset of the fixed
            field names below (``csv.DictWriter`` raises ``ValueError`` otherwise);
            missing keys are written as empty cells.
        resolution_criteria: Free text for ``market_resolution_criteria``.
        market_end_date: Pre-formatted value for ``market_end_date``.
        market_resolution_date: Pre-formatted value for ``market_resolution_date``.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    svelte_params = {"scaleType": "linear"}
    if mode == "daily_submarkets":
        svelte_params["timeCadence"] = "monthly"
    elif mode == "weekly_sum":
        svelte_params["timeCadence"] = "weekly"

    x_unit = "calls/week" if mode == "weekly_sum" else "calls/day"
    threshold_count = len(WEEKLY_SUM_THRESHOLDS) if mode == "weekly_sum" else len(COUNT_THRESHOLDS)

    # Free-text values are flattened so each metadata entry stays on its own "#" line.
    meta_lines = [
        f"# market_code: {_single_line_meta(market_code)}",
        f"# market_title: {_single_line_meta(market_title)}",
        "# market_type: count",
        "# market_visibility: public",
        "# market_status: draft",
        f"# market_budget: {MARKET_BUDGET}",
        f"# market_decay_rate: {MARKET_DECAY_RATE:.3f}",
        f"# market_resolution_criteria: {_single_line_meta(resolution_criteria)}",
        f"# market_x_unit: {x_unit}",
        "# market_number_format: decimal",
        f"# market_end_date: {market_end_date}",
        f"# market_resolution_date: {market_resolution_date}",
        "# market_cumulative: false",
        f"# market_background_info_path: {_single_line_meta(background_filename)}",
        f"# market_svelte_params: {json.dumps(svelte_params, separators=(',', ':'))}",
        f"# generated_note: mode={mode}, reference_class_mode={REFERENCE_CLASS_MODE}, hierarchical_recovery_partial_pooling, thresholds={threshold_count}",
        "",  # trailing empty entry -> the join below ends the header with "\n"
    ]

    fieldnames = [
        "projection_group",
        "threshold_decimal",
        "threshold_date",
        "initial_probability",
        "label",
        "end_date",
        "decay_rate",
        "status",
    ]

    with output_path.open("w", encoding="utf-8", newline="") as handle:
        handle.write("\n".join(meta_lines))
        # The default "excel" dialect terminates rows with "\r\n"; pin "\n" so the data
        # rows use the same line ending as the metadata header written above.
        writer = csv.DictWriter(handle, fieldnames=fieldnames, lineterminator="\n")
        writer.writeheader()
        writer.writerows(rows)


def main() -> int:
    """Fetch PortWatch data, fit the recovery model and write every market artifact.

    All outputs go to the directory containing ``--out``: the market CSV,
    ``background_info.html``, ``model_summary.json`` and the charts produced by
    ``render_charts``.

    Returns:
        Process exit status (0 on success).

    Raises:
        SystemExit: via ``argparse`` on invalid arguments, including a
            ``--forecast-start`` that is not a ``YYYY-MM-DD`` date.
        RuntimeError: if the forecast start is after ``FORECAST_END_DATE``, the
            PortWatch query returns no rows, or no rows exist for ``TARGET_PORT_ID``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=Path, default=DEFAULT_CSV_PATH)
    parser.add_argument(
        "--mode",
        choices=["monthly_average", "daily_submarkets", "weekly_sum"],
        default="monthly_average",
        help="Output market structure: monthly-average, weekly-sum, or daily submarkets.",
    )
    parser.add_argument("--market-code", type=str, default=None)
    parser.add_argument("--market-title", type=str, default=None)
    parser.add_argument(
        "--forecast-start",
        type=str,
        default=datetime.now(timezone.utc).date().isoformat(),
        help="Forecast start date in YYYY-MM-DD (UTC).",
    )
    args = parser.parse_args()

    try:
        forecast_start = datetime.strptime(args.forecast_start, "%Y-%m-%d").date()
    except ValueError:
        # Surface a malformed date as a CLI usage error rather than a bare traceback.
        parser.error(f"--forecast-start must be a YYYY-MM-DD date, got {args.forecast_start!r}")
    if forecast_start > FORECAST_END_DATE:
        raise RuntimeError(
            f"forecast-start {forecast_start.isoformat()} is after fixed end date {FORECAST_END_DATE.isoformat()}"
        )

    data = fetch_features(where="1=1", out_fields=["date", "portid", "portname", "n_total"])
    if data.empty:
        raise RuntimeError("No rows returned from PortWatch API")

    episodes = detect_episodes(data)
    ratio_df = build_ratio_frame(data)
    recovery_params, attenuation_params, diagnostics = fit_recovery_model(episodes, ratio_df)

    hormuz = data[data["portid"] == TARGET_PORT_ID].sort_values("date").copy()
    if hormuz.empty:
        raise RuntimeError(f"No rows found for {TARGET_PORT_ID}")

    regime = fit_regime_params(hormuz)

    latest_date = hormuz["date"].max().to_pydatetime()
    # The ongoing event has lasted at least this many days (inclusive of the start day).
    known_survival_days = max(1, (latest_date - EVENT_START_DATE).days + 1)

    if args.mode == "monthly_average":
        market_code = args.market_code or DEFAULT_MARKET_CODE_MONTHLY
        market_title = args.market_title or DEFAULT_MARKET_TITLE_MONTHLY
    elif args.mode == "weekly_sum":
        market_code = args.market_code or DEFAULT_MARKET_CODE_WEEKLY
        market_title = args.market_title or DEFAULT_MARKET_TITLE_WEEKLY
    else:
        market_code = args.market_code or DEFAULT_MARKET_CODE_DAILY
        market_title = args.market_title or DEFAULT_MARKET_TITLE_DAILY

    output_csv = args.out.resolve()
    out_dir = output_csv.parent
    # render_charts writes into out_dir before write_market_csv (which also creates the
    # directory) runs, so ensure it exists before any artifact is written.
    out_dir.mkdir(parents=True, exist_ok=True)
    background_path = out_dir / "background_info.html"
    summary_path = out_dir / "model_summary.json"
    hormuz_chart_path = out_dir / "hormuz_calls_recent.png"
    recovery_fit_path = out_dir / "recovery_duration_fit.png"

    render_charts(
        hormuz_df=hormuz,
        episodes=episodes,
        known_survival_days=known_survival_days,
        recovery_params=recovery_params,
        out_dir=out_dir,
    )

    if args.mode == "monthly_average":
        rows = build_monthly_average_rows(
            forecast_start=forecast_start,
            hormuz_df=hormuz,
            recovery_params=recovery_params,
            attenuation_params=attenuation_params,
            known_survival_days=known_survival_days,
            regime=regime,
        )
    elif args.mode == "weekly_sum":
        rows = build_weekly_sum_rows(
            forecast_start=forecast_start,
            hormuz_df=hormuz,
            recovery_params=recovery_params,
            attenuation_params=attenuation_params,
            known_survival_days=known_survival_days,
            regime=regime,
        )
    else:
        rows = build_daily_rows(
            forecast_start=forecast_start,
            recovery_params=recovery_params,
            attenuation_params=attenuation_params,
            known_survival_days=known_survival_days,
            regime=regime,
        )

    resolution_criteria = build_resolution_criteria(args.mode)

    # Weekly markets close at the end of the last generated week; the other modes
    # use the fixed forecast end date. Resolution follows 90 days after the close.
    if args.mode == "weekly_sum" and rows:
        last_week_end = max(
            datetime.strptime(row["end_date"], "%Y-%m-%dT%H:%M:%SZ").date()
            for row in rows
        )
        market_end_date = day_end_iso(last_week_end)
        market_resolution_date = day_end_iso(last_week_end + timedelta(days=90))
    else:
        market_end_date = day_end_iso(FORECAST_END_DATE)
        market_resolution_date = day_end_iso(FORECAST_END_DATE + timedelta(days=90))

    write_market_csv(
        output_path=output_csv,
        market_code=market_code,
        market_title=market_title,
        mode=args.mode,
        background_filename=background_path.name,
        rows=rows,
        resolution_criteria=resolution_criteria,
        market_end_date=market_end_date,
        market_resolution_date=market_resolution_date,
    )

    background_html = build_background_html(
        mode=args.mode,
        market_title=market_title,
        hormuz_df=hormuz,
        known_survival_days=known_survival_days,
        recovery_params=recovery_params,
        attenuation_params=attenuation_params,
        regime=regime,
        diagnostics=diagnostics,
        forecast_start=forecast_start,
    )
    background_path.write_text(background_html, encoding="utf-8")

    summary = {
        "generated_at_utc": datetime.now(timezone.utc).replace(microsecond=0).isoformat(),
        "script": str(SCRIPT_PATH),
        "source": {
            "dataset_url": PORTWATCH_QUERY_URL,
            "page_url": PORTWATCH_PAGE_URL,
            "target_portid": TARGET_PORT_ID,
            "target_portname": TARGET_PORT_NAME,
            "latest_observation_date_utc": latest_date.date().isoformat(),
        },
        "market": {
            "code": market_code,
            "title": market_title,
            "type": "count",
            "mode": args.mode,
            "forecast_start_date_utc": forecast_start.isoformat(),
            "forecast_end_date_utc": FORECAST_END_DATE.isoformat(),
            "target_days": (FORECAST_END_DATE - forecast_start).days + 1,
            "thresholds": WEEKLY_SUM_THRESHOLDS if args.mode == "weekly_sum" else COUNT_THRESHOLDS,
            "rows": len(rows),
        },
        "artifacts": {
            "hormuz_calls_recent_chart": hormuz_chart_path.name,
            "recovery_duration_fit_chart": recovery_fit_path.name,
            "script_download_asset": SCRIPT_PATH.name,
        },
        "event_definition": {
            "event_start_date_utc": EVENT_START_DATE.date().isoformat(),
            "known_survival_days": known_survival_days,
            "recovery_condition": "7d_mean_n_total >= 60 for 14 consecutive days",
        },
        "reference_classes": diagnostics,
        "recovery_model_params": {
            "mu_global_severe": recovery_params.mu_global_severe,
            "sigma_global_severe": recovery_params.sigma_global_severe,
            "mu_global_stress": recovery_params.mu_global_stress,
            "sigma_global_stress": recovery_params.sigma_global_stress,
            "mu_external": recovery_params.mu_external,
            "sigma_external": recovery_params.sigma_external,
            "mu_conflict": recovery_params.mu_conflict,
            "sigma_conflict": recovery_params.sigma_conflict,
            "mu_hormuz": recovery_params.mu_hormuz,
            "sigma_hormuz": recovery_params.sigma_hormuz,
            "weights": {
                "hormuz": recovery_params.weight_hormuz,
                "conflict_affected": recovery_params.weight_conflict,
                "global_severe": recovery_params.weight_global_severe,
                "global_stress": recovery_params.weight_global_stress,
                "external_analogues": recovery_params.weight_external,
            },
        },
        "post_recovery_attenuation_params": {
            "ratio30_global": attenuation_params.ratio30_global,
            "ratio30_conflict": attenuation_params.ratio30_conflict,
            "ratio30_hormuz": attenuation_params.ratio30_hormuz,
            "ratio60_global": attenuation_params.ratio60_global,
            "ratio60_conflict": attenuation_params.ratio60_conflict,
            "ratio60_hormuz": attenuation_params.ratio60_hormuz,
            "ratio90_global": attenuation_params.ratio90_global,
            "ratio90_conflict": attenuation_params.ratio90_conflict,
            "ratio90_hormuz": attenuation_params.ratio90_hormuz,
            "decay_per_day": attenuation_params.decay_per_day,
            "half_life_days": attenuation_params.half_life_days,
        },
        "regime_params": {
            "baseline_mean": regime.baseline_mean,
            "baseline_var": regime.baseline_var,
            "baseline_size": regime.baseline_size,
            "collapse_mean": regime.collapse_mean,
            "collapse_var": regime.collapse_var,
            "collapse_size": regime.collapse_size,
        },
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(f"Wrote market CSV: {output_csv}")
    print(f"Wrote background HTML: {background_path}")
    print(f"Wrote summary JSON: {summary_path}")
    return 0


if __name__ == "__main__":
    # Run the CLI and hand its integer status back to the interpreter as the exit code.
    exit_status = main()
    raise SystemExit(exit_status)
