#!/usr/bin/env python3

import argparse
import csv
import datetime as _dt
import statistics
import subprocess
import sys
import time
from pathlib import Path


def _read_setting(path: Path) -> "str | None":
    """Return the stripped text of a kernel setting file, or None if unavailable.

    The environment checks are advisory only, so a missing or unreadable file
    (e.g. no cpufreq driver, restricted permissions) is skipped instead of
    aborting the benchmark run.
    """
    try:
        return path.read_text(encoding="utf-8").strip()
    except OSError:
        return None


def _parse_cpu_list(spec: str) -> list[int]:
    """Expand a ``taskset -c`` style CPU list into individual CPU ids.

    Accepts comma-separated entries of the form ``N``, ``N-M`` (inclusive) and
    ``N-M:S`` (inclusive range with stride ``S``), e.g. ``"0-2,5"`` becomes
    ``[0, 1, 2, 5]``. Empty entries are ignored.

    Raises:
        ValueError: if an entry is not in one of the accepted forms.
    """
    cpus: list[int] = []
    for entry in spec.split(","):
        entry = entry.strip()
        if not entry:
            continue
        span, _, stride_text = entry.partition(":")
        first_text, dash, last_text = span.partition("-")
        first = int(first_text)
        last = int(last_text) if dash else first
        stride = int(stride_text) if stride_text else 1
        # range() itself rejects a zero stride with ValueError.
        cpus.extend(range(first, last + 1, stride))
    return cpus


def warn_if_benchmark_environment_is_noisy(cpu: str = "2") -> None:
    """Print warnings to stderr for system settings that add benchmark noise.

    Args:
        cpu: The CPU list the benchmark is pinned to, in ``taskset -c`` format
            (default ``"2"``). Every listed CPU is expected to use the
            ``performance`` cpufreq scaling governor. If the list cannot be
            parsed, a warning is printed and the governor check is skipped.

    Also warns when ``kernel.randomize_va_space`` is not ``0``. Settings whose
    files cannot be read are silently skipped. Never raises.
    """
    try:
        cpu_ids = sorted(set(_parse_cpu_list(cpu)))
    except ValueError:
        print(
            f"warning: cannot parse CPU list '{cpu}', skipping scaling governor check",
            file=sys.stderr,
        )
        cpu_ids = []

    for cpu_id in cpu_ids:
        governor = _read_setting(
            Path(f"/sys/devices/system/cpu/cpu{cpu_id}/cpufreq/scaling_governor")
        )
        if governor is not None and governor != "performance":
            print(
                f"warning: cpu{cpu_id} scaling governor is '{governor}', expected 'performance'",
                file=sys.stderr,
            )

    randomize_va_space = _read_setting(Path("/proc/sys/kernel/randomize_va_space"))
    if randomize_va_space is not None and randomize_va_space != "0":
        print(
            f"warning: kernel.randomize_va_space is '{randomize_va_space}', expected '0'",
            file=sys.stderr,
        )


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: '{text}'") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {value}")
    return value


def _parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``.

    Returns:
        Namespace with ``runs`` (int >= 1), ``rom`` (Path), ``cpu`` (str),
        ``max_cycles`` (int) and ``csv`` (Path or None).

    Exits via argparse (status 2) on invalid arguments, including ``runs < 1``,
    which would otherwise leave no samples to summarize.
    """
    parser = argparse.ArgumentParser(description="Run TLMBoy benchmarks across toolchains")
    parser.add_argument(
        "runs",
        nargs="?",
        type=_positive_int,
        default=11,
        help="Number of runs per binary (default: 11)",
    )
    parser.add_argument(
        "--rom",
        type=Path,
        default=Path.home() / "Games" / "flappyboy.gb",
        help="ROM path (default: ~/Games/flappyboy.gb)",
    )
    parser.add_argument(
        "--cpu",
        default="2",
        help="CPU core to pin via taskset -c (default: 2)",
    )
    parser.add_argument(
        "--max-cycles",
        type=int,
        default=15_000_000,
        help="Max cycles passed to the emulator (default: 15000000)",
    )
    parser.add_argument(
        "--csv",
        type=Path,
        default=None,
        help="Output CSV path (default: bm/benchmark_results.csv)",
    )
    return parser.parse_args(argv)


# Column order of the per-run CSV written by main().
_CSV_FIELDS = [
    "started_at_utc",
    "toolchain",
    "binary",
    "run",
    "ms",
    "cpu",
    "rom",
    "max_cycles",
]


def _exit_status(returncode: "int | None") -> int:
    """Map a child process return code to an exit status for this script.

    ``subprocess`` reports death by signal N as return code ``-N``. Passing
    that to ``SystemExit`` would produce an unrelated status on POSIX
    (``-11`` -> 245), so it is mapped to the shell convention ``128 + N``
    (SIGSEGV -> 139). A missing return code maps to 1.
    """
    if returncode is None:
        return 1
    if returncode < 0:
        return 128 - returncode
    return returncode


def _benchmark_command(cpu: str, binary_path: Path, rom: Path, max_cycles: int) -> list[str]:
    """Build the argv that runs *binary_path* on *rom*, pinned to *cpu* via taskset."""
    return [
        "taskset",
        "-c",
        cpu,
        str(binary_path),
        "-r",
        str(rom),
        "--fps-cap=-1",
        f"--max-cycles={max_cycles}",
        "--headless",
    ]


def _print_summary(binary_name: str, times_ms: list[float]) -> None:
    """Print median/lowest/highest wall time for one binary; *times_ms* must be non-empty."""
    print()
    print(f"{binary_name}:")
    print(f"  median: {statistics.median(times_ms)} ms")
    print(f"  lowest: {min(times_ms)} ms")
    print(f"  highest: {max(times_ms)} ms")


def _write_csv(csv_path: Path, rows: list[dict[str, object]]) -> None:
    """Write *rows* (keys = ``_CSV_FIELDS``) to *csv_path*, overwriting it.

    Missing parent directories are created.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def main() -> int:
    """Benchmark each ``tlmboy_<toolchain>`` binary next to this script.

    Each binary is run ``runs`` times, pinned with ``taskset -c <cpu>``.
    Median, lowest and highest wall times are printed per binary. All
    per-run timings are written to the CSV, even when the run stops early
    on error or Ctrl-C, as long as at least one run completed.

    Returns:
        The process exit status:
        - 0 on success.
        - 2 if a benchmark binary is missing.
        - The failing command's status on failure, with signal deaths mapped
          to 128+N.
        - 127 if the command cannot be launched, e.g. taskset is not installed.
        - 130 on Ctrl-C.
    """
    args = _parse_args()
    runs: int = args.runs
    rom: Path = args.rom
    cpu = str(args.cpu)
    max_cycles = int(args.max_cycles)
    script_dir = Path(__file__).resolve().parent
    csv_path: Path = args.csv if args.csv is not None else (script_dir / "benchmark_results.csv")

    toolchains = [
        "clang18",
        "clang19",
        "clang20",
        "gcc13",
        "gcc14",
        "gcc15",
    ]

    # Check the cores actually used for pinning, not a hard-coded one.
    warn_if_benchmark_environment_is_noisy(cpu)

    rows: list[dict[str, object]] = []
    started_at = _dt.datetime.now(tz=_dt.timezone.utc).isoformat()

    exit_code = 0
    try:
        for toolchain in toolchains:
            binary_path = script_dir / f"tlmboy_{toolchain}"
            if not binary_path.exists():
                print(f"error: missing benchmark binary: {binary_path}", file=sys.stderr)
                exit_code = 2
                break

            command = _benchmark_command(cpu, binary_path, rom, max_cycles)
            times_ms: list[float] = []
            for run_index in range(1, runs + 1):
                # Wall time of the whole command, so taskset and process
                # startup are included in every sample.
                start = time.perf_counter_ns()
                subprocess.run(
                    command,
                    check=True,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.PIPE,
                    text=True,
                )
                elapsed_ms = (time.perf_counter_ns() - start) / 1e6
                times_ms.append(elapsed_ms)
                rows.append(
                    {
                        "started_at_utc": started_at,
                        "toolchain": toolchain,
                        "binary": binary_path.name,
                        "run": run_index,
                        "ms": elapsed_ms,
                        "cpu": cpu,
                        "rom": str(rom),
                        "max_cycles": max_cycles,
                    }
                )

            # runs >= 1 is enforced by _parse_args, so times_ms is non-empty.
            _print_summary(binary_path.name, times_ms)
    except subprocess.CalledProcessError as e:
        print(f"error: benchmark command failed with exit code {e.returncode}", file=sys.stderr)
        if e.stderr:
            print(e.stderr, file=sys.stderr, end="" if e.stderr.endswith("\n") else "\n")
        exit_code = _exit_status(e.returncode)
    except FileNotFoundError as e:
        # subprocess.run raises this when the executable (taskset) is not found.
        print(f"error: could not launch benchmark command: {e}", file=sys.stderr)
        exit_code = 127
    except KeyboardInterrupt:
        print("\ninterrupted", file=sys.stderr)
        exit_code = 130
    finally:
        # Persist whatever was measured, even after a failure or interrupt.
        if rows:
            _write_csv(csv_path, rows)
            print()
            print(f"wrote: {csv_path}")

    return exit_code


# Only run the benchmark when executed as a script, so importing this module
# (e.g. from tests or other tooling) does not launch benchmarks and exit.
if __name__ == "__main__":
    raise SystemExit(main())

