#!/usr/bin/env python3
"""Lightweight Bitaxe agent to register miners and forward local stats to the hub."""

import argparse
import asyncio
import ipaddress
import os
import re
import socket
import sys
import textwrap
from typing import Any, Dict, List, Optional

import httpx

# Hub base URL; the BITAXE_HUB_URL environment variable overrides the built-in
# default, and the --hub command-line option overrides both.
DEFAULT_HUB = os.environ.get("BITAXE_HUB_URL", "https://rmt.bitaxermt.xyz")
# Default for --interval: seconds between report rounds.
DEFAULT_INTERVAL = 30
# Prefix length applied to the local IPv4 address when no --scan-subnet is given.
DEFAULT_SCAN_PREFIX = 24
# Discovery refuses to scan subnets with more hosts than this.
MAX_DISCOVERY_HOSTS = 1024
# Size of the semaphore that bounds simultaneous discovery probes.
DISCOVERY_CONCURRENCY = 48
# Value used for each httpx timeout phase (connect/read/write/pool) of the
# discovery client.
DISCOVERY_TIMEOUT_SECONDS = 1.5


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace with the attributes ``hub``, ``interval``, ``agent_id``,
        ``miner`` (list of ``id@ip`` strings, or ``None`` when the option was
        not given), ``scan_subnet``, ``no_discovery`` and ``manual``.
    """
    usage_notes = textwrap.dedent(
        """
Auto-discovery runs by default when no --miner entries are provided.
Use multiple --miner entries to predefine the inventory (format: id@ip).
Use --manual to enter miners interactively.
""",
    )
    cli = argparse.ArgumentParser(
        description="Bitaxe LAN agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )

    # Hub connection and reporting cadence.
    cli.add_argument("--hub", default=DEFAULT_HUB, help="Hub base URL (https://host)")
    cli.add_argument("--interval", type=int, default=DEFAULT_INTERVAL, help="Seconds between reports")
    cli.add_argument("--agent-id", help="Agent identifier (defaults to hostname)")

    # Miner inventory sources.
    cli.add_argument("--miner", action="append", help="Define a miner as id@ip (repeatable)")
    cli.add_argument(
        "--scan-subnet",
        help="Subnet to scan for auto-discovery, for example 192.168.1.0/24",
    )

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--no-discovery", "Skip auto-discovery (requires --miner or --manual)"),
        (
            "--manual",
            "Prompt for manual miner entries if auto-discovery is skipped or returns no miners",
        ),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    return cli.parse_args()


def _safe_id(raw: str, fallback: str) -> str:
    """Normalise *raw* into a lowercase identifier of at most 64 characters.

    Surrounding whitespace is removed, every run of characters outside
    ``[a-zA-Z0-9_-]`` collapses to a single ``-`` and leading/trailing hyphens
    are dropped. When nothing is left, *fallback* is returned unchanged.
    """
    lowered = raw.strip().lower()
    slug = re.sub(r"[^a-zA-Z0-9_-]+", "-", lowered)
    slug = slug.strip("-")
    if not slug:
        return fallback
    return slug[:64]


def _local_ipv4() -> Optional[str]:
    """Best-effort lookup of a non-loopback IPv4 address for this machine.

    Two strategies are tried in order; any exception in either one is ignored:

    1. Connect a UDP socket to 1.1.1.1:53 and read the socket's local address.
    2. Resolve the local hostname and take the first usable address.

    Returns:
        The first address found that does not start with ``127.``, or ``None``.
    """

    def usable(addr: str) -> bool:
        return bool(addr) and not addr.startswith("127.")

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("1.1.1.1", 53))
            outbound = probe.getsockname()[0]
        if usable(outbound):
            return outbound
    except Exception:
        pass

    try:
        addresses = socket.gethostbyname_ex(socket.gethostname())[2]
        for candidate in addresses:
            if usable(candidate):
                return candidate
    except Exception:
        pass

    return None


def _scan_network(scan_subnet: Optional[str]) -> Optional[ipaddress.IPv4Network]:
    """Resolve the IPv4 network that auto-discovery should scan.

    Args:
        scan_subnet: Subnet given by the user (host bits are allowed and are
            masked off). When empty or ``None``, the network is derived from
            the local IPv4 address and ``DEFAULT_SCAN_PREFIX``.

    Returns:
        The IPv4 network, or ``None`` when the subnet cannot be parsed, is not
        IPv4, or no local address could be determined.
    """
    if not scan_subnet:
        local_ip = _local_ipv4()
        if not local_ip:
            return None
        try:
            return ipaddress.ip_network(f"{local_ip}/{DEFAULT_SCAN_PREFIX}", strict=False)
        except Exception:
            return None

    try:
        parsed = ipaddress.ip_network(scan_subnet, strict=False)
    except Exception:
        return None
    # An explicit IPv6 subnet is rejected rather than scanned.
    return parsed if isinstance(parsed, ipaddress.IPv4Network) else None


def collect_spec_miners(specs: Optional[List[str]]) -> List[Dict[str, str]]:
    """Parse ``--miner`` specs of the form ``id@ip`` into miner records.

    Every rejected spec is reported on stdout so the operator can see why a
    miner is missing from the inventory.

    Args:
        specs: Raw ``id@ip`` strings, or ``None``/empty when none were given.

    Returns:
        One ``{"id": ..., "ip": ...}`` dict per accepted spec, in input order.
        The id is sanitised with ``_safe_id``; an id that sanitises to nothing
        falls back to ``miner-<ip with dots replaced by dashes>``. Only the
        first spec for each resulting id is kept.
    """
    miners: List[Dict[str, str]] = []
    if not specs:
        return miners

    seen: set[str] = set()
    for spec in specs:
        # Split on the first '@' only; everything after it is the address.
        miner_id_raw, separator, ip = spec.partition("@")
        if not separator:
            print(f"Ignoring invalid spec '{spec}'. Expected format id@ip.")
            continue
        ip = ip.strip()
        if not ip:
            print(f"Ignoring spec '{spec}': missing IP/hostname after '@'.")
            continue
        fallback_id = f"miner-{ip.replace('.', '-')}"
        miner_id = _safe_id(miner_id_raw, fallback_id)
        if miner_id in seen:
            # Same wording as the interactive entry path.
            print(f"Duplicate miner id '{miner_id}' ignored.")
            continue
        seen.add(miner_id)
        miners.append({"id": miner_id, "ip": ip})
    return miners


def collect_manual_miners() -> List[Dict[str, str]]:
    """Interactively prompt for miner id / address pairs on stdin.

    Entry ends when a blank miner id is given or when stdin reaches end of
    file (closed or piped input), so a non-interactive run finishes with the
    miners collected so far instead of aborting.

    Returns:
        One ``{"id": ..., "ip": ...}`` dict per accepted entry, in entry
        order. Ids are sanitised with ``_safe_id``; duplicates are reported
        and skipped.
    """
    miners: List[Dict[str, str]] = []
    seen: set[str] = set()
    print("\nEnter miner IDs and the LAN IP/hostname for each ASIC.")
    while True:
        try:
            miner_id = input("Miner ID (blank to finish): ").strip()
            if not miner_id:
                break
            ip = input("IPv4/hostname: ").strip()
        except EOFError:
            # No more input is coming; treat it like a blank miner id.
            print()
            break
        if not ip:
            print("IP required for each miner. Skipping this entry.")
            continue
        fallback_id = f"miner-{ip.replace('.', '-')}"
        cleaned = _safe_id(miner_id, fallback_id)
        if cleaned in seen:
            print(f"Duplicate miner id '{cleaned}' ignored.")
            continue
        seen.add(cleaned)
        miners.append({"id": cleaned, "ip": ip})
    return miners


def ensure_api_key() -> str:
    """Return the hub API key from the environment or an interactive prompt.

    The ``BITAXE_API_KEY`` environment variable is used when it contains
    anything other than whitespace. Otherwise the user is prompted until a
    non-empty key is entered.

    Returns:
        The API key with surrounding whitespace removed; never empty.
    """
    # Strip before testing so a whitespace-only variable is not accepted as
    # an (empty) key.
    key = os.environ.get("BITAXE_API_KEY", "").strip()
    if key:
        return key
    while True:
        key = input("Hub API key (X-Api-Key): ").strip()
        if key:
            return key
        print("API key is required to register and report miners.")


def format_hub_url(hub_url: str) -> str:
    """Return *hub_url* with every trailing ``/`` removed."""
    trimmed = hub_url.rstrip("/")
    return trimmed


async def fetch_system_info(
    client: httpx.AsyncClient,
    ip: str,
    *,
    verbose_error: bool = True,
) -> Optional[Dict[str, Any]]:
    """Fetch ``http://<ip>/api/system/info`` and return the decoded JSON object.

    Args:
        client: Shared HTTP client used for the request.
        ip: Address or hostname of the device to query.
        verbose_error: When true, failures are printed to stdout.

    Returns:
        The decoded JSON object, or ``None`` when the request fails, the
        status is an error, the body is not valid JSON, or the JSON value is
        not an object.

    Note:
        The request passes an explicit 10-second timeout, which takes
        precedence over any timeout configured on *client*.
    """
    try:
        resp = await client.get(f"http://{ip}/api/system/info", timeout=10.0)
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        # Best effort: an unreachable or misbehaving device must not stop the agent.
        if verbose_error:
            print(f"[!] Failed to query {ip}: {exc}")
        return None

    # Callers use dict methods on the result; reject lists, strings, numbers.
    if not isinstance(data, dict):
        if verbose_error:
            print(f"[!] Failed to query {ip}: unexpected JSON payload ({type(data).__name__})")
        return None
    return data


def _bitaxe_like(info: Dict[str, Any]) -> bool:
    """Decide whether a system-info payload looks like a Bitaxe-style miner.

    A payload qualifies when it has a truthy ``ASICModel``, contains a
    ``hashRate`` key (any value), or mentions "bitaxe" or "nerdaxe"
    (case-insensitive) in its hostname, device model, board version or
    version fields.
    """
    if info.get("ASICModel") or "hashRate" in info:
        return True

    descriptive_keys = ("hostname", "deviceModel", "boardVersion", "version")
    haystack = " ".join(str(info.get(key, "")) for key in descriptive_keys).lower()
    return any(marker in haystack for marker in ("bitaxe", "nerdaxe"))


def _name_from_info(info: Dict[str, Any], ip: str) -> str:
    """Pick a display name for a miner from its system-info payload.

    Uses the first truthy value of ``hostname`` then ``deviceModel``,
    converted to ``str`` and stripped. Falls back to *ip* when the result is
    empty.
    """
    raw_name = info.get("hostname") or info.get("deviceModel") or ""
    label = str(raw_name).strip()
    if label:
        return label
    return ip


async def discover_miners(scan_subnet: Optional[str]) -> List[Dict[str, str]]:
    """Probe every host of a subnet and return the Bitaxe-compatible miners.

    Args:
        scan_subnet: Subnet to scan, or ``None``/empty to derive one from the
            local IPv4 address (see ``_scan_network``).

    Returns:
        ``{"id", "ip", "name"}`` dicts sorted numerically by IP address. The
        list is empty when the subnet cannot be determined, has more than
        ``MAX_DISCOVERY_HOSTS`` hosts, or no miner answered. When two miners
        produce the same id, the later one gets its dashed IP appended.
    """
    network = _scan_network(scan_subnet)
    if network is None:
        print("[!] Could not determine scan subnet. Use --scan-subnet or enter miners manually.")
        return []

    # Check the size before materialising the host list: a /8 would otherwise
    # allocate millions of strings just to be rejected. Any network large
    # enough to exceed the limit excludes its network and broadcast
    # addresses, so the host count is num_addresses - 2.
    host_count = network.num_addresses - 2
    if host_count > MAX_DISCOVERY_HOSTS:
        print(
            f"[!] Discovery subnet {network} has {host_count} hosts. "
            f"Limit is {MAX_DISCOVERY_HOSTS}; use a smaller --scan-subnet."
        )
        return []

    hosts = [str(ip) for ip in network.hosts()]
    print(f"[*] Auto-discovery scanning {network} ({len(hosts)} hosts)...")
    timeout = httpx.Timeout(
        connect=DISCOVERY_TIMEOUT_SECONDS,
        read=DISCOVERY_TIMEOUT_SECONDS,
        write=DISCOVERY_TIMEOUT_SECONDS,
        pool=DISCOVERY_TIMEOUT_SECONDS,
    )
    sem = asyncio.Semaphore(DISCOVERY_CONCURRENCY)

    async with httpx.AsyncClient(timeout=timeout) as client:

        async def probe(ip: str) -> Optional[Dict[str, str]]:
            async with sem:
                # The request is issued here rather than via
                # fetch_system_info(): that helper passes its own per-request
                # timeout, which would override the short discovery timeout
                # configured on this client and make every silent host slow.
                try:
                    resp = await client.get(f"http://{ip}/api/system/info")
                    resp.raise_for_status()
                    info = resp.json()
                except Exception:
                    # Most addresses are not miners; failures are expected.
                    return None
            # Arbitrary LAN hosts may answer with any JSON value.
            if not isinstance(info, dict) or not _bitaxe_like(info):
                return None
            miner_name = _name_from_info(info, ip)
            fallback_id = f"miner-{ip.replace('.', '-')}"
            miner_id = _safe_id(miner_name, fallback_id)
            return {"id": miner_id, "ip": ip, "name": miner_name}

        discovered = await asyncio.gather(*[probe(ip) for ip in hosts])

    miners: List[Dict[str, str]] = []
    ids: set[str] = set()
    for entry in discovered:
        if not entry:
            continue
        miner_id = entry["id"]
        if miner_id in ids:
            # Disambiguate identical names (e.g. default hostnames) by IP.
            suffix = entry["ip"].replace(".", "-")
            miner_id = _safe_id(f"{miner_id}-{suffix}", f"miner-{suffix}")
            entry["id"] = miner_id
        ids.add(miner_id)
        miners.append(entry)

    miners.sort(key=lambda m: int(ipaddress.ip_address(m["ip"])))
    if miners:
        print("[*] Found miners:")
        for miner in miners:
            print(f"    - {miner['id']} @ {miner['ip']}")
    else:
        print("[!] No Bitaxe-compatible miners found with auto-discovery.")
    return miners


def summarize_miners(miners: List[Dict[str, str]]) -> str:
    """Render miners as ``id@ip`` entries separated by ``" | "``."""
    labels = [f"{entry['id']}@{entry['ip']}" for entry in miners]
    return " | ".join(labels)


async def register_agent(
    client: httpx.AsyncClient,
    hub_url: str,
    api_key: str,
    agent_id: str,
    miner_ids: List[str],
) -> None:
    """Register this agent and its miner ids with the hub.

    POSTs ``agentId`` and ``minerIds`` to ``<hub_url>/api/agents/register``
    with the API key in the ``X-Api-Key`` header, then prints how many entries
    the response lists under ``assignedMiners``.

    Raises:
        Whatever the HTTP client raises for transport failures or, via
        ``raise_for_status()``, for error status codes.
    """
    body = {"agentId": agent_id, "minerIds": miner_ids}
    auth_headers = {"X-Api-Key": api_key, "Content-Type": "application/json"}
    endpoint = f"{hub_url}/api/agents/register"

    response = await client.post(endpoint, json=body, headers=auth_headers, timeout=15.0)
    response.raise_for_status()

    assigned_count = len(response.json().get("assignedMiners", []))
    print(f"[*] Registered agent '{agent_id}' with {assigned_count} miners")


async def report_miner(
    client: httpx.AsyncClient,
    hub_url: str,
    api_key: str,
    agent_id: str,
    miner: Dict[str, str],
    info: Dict[str, Any],
) -> None:
    """Forward one miner's system info to the hub.

    POSTs the agent id, miner id, display name (the miner's ``name`` when
    present and non-empty, otherwise its id), address and raw *info* to
    ``<hub_url>/api/agents/report``, then prints a confirmation line.

    Raises:
        Whatever the HTTP client raises for transport failures or, via
        ``raise_for_status()``, for error status codes.
    """
    display_name = miner.get("name") or miner["id"]
    body = {
        "agentId": agent_id,
        "minerId": miner["id"],
        "minerName": display_name,
        "ip": miner["ip"],
        "info": info,
    }
    auth_headers = {"X-Api-Key": api_key, "Content-Type": "application/json"}
    endpoint = f"{hub_url}/api/agents/report"

    response = await client.post(endpoint, json=body, headers=auth_headers, timeout=15.0)
    response.raise_for_status()
    print(f"[+] Reported {miner['id']} ({miner['ip']})")


async def main() -> None:
    """Run the agent: build the miner inventory, register, then report forever.

    The inventory comes from ``--miner`` specs first, then auto-discovery
    (unless ``--no-discovery``), then interactive entry (only with
    ``--manual``). With no miners the function prints guidance and returns.
    Otherwise it registers with the hub and, every ``interval`` seconds
    (minimum 5), polls each miner and forwards its info. A miner that cannot
    be polled or reported is logged and skipped for that round.

    Raises:
        Any exception raised while registering with the hub; the caller is
        expected to report it.
    """
    args = parse_args()
    hub_url = format_hub_url(args.hub)
    miners = collect_spec_miners(args.miner)
    if not miners and not args.no_discovery:
        miners = await discover_miners(args.scan_subnet)
    if not miners and args.manual:
        miners = collect_manual_miners()
    if not miners:
        print("No miners discovered. Ensure Bitaxes are reachable on your LAN or pass --scan-subnet.")
        print("Use --miner id@ip for fixed entries, or add --manual for interactive entry.")
        return

    api_key = ensure_api_key()
    agent_id = args.agent_id or os.environ.get("BITAXE_AGENT_ID") or socket.gethostname()
    interval = max(5, args.interval)
    print("\nBitaxe agent starting:")
    print(f"  hub: {hub_url}")
    print(f"  agent id: {agent_id}")
    print(f"  miners: {summarize_miners(miners)}")
    print(f"  interval: {interval}s\n")

    async with httpx.AsyncClient() as client:
        try:
            await register_agent(client, hub_url, api_key, agent_id, [m["id"] for m in miners])
            while True:
                for miner in miners:
                    info = await fetch_system_info(client, miner["ip"])
                    if not info:
                        continue
                    try:
                        await report_miner(client, hub_url, api_key, agent_id, miner, info)
                    except Exception as exc:
                        print(f"[!] Hub report failed for {miner['id']}: {exc}")
                await asyncio.sleep(interval)
        except (KeyboardInterrupt, asyncio.CancelledError):
            # Under asyncio.run(), Ctrl-C normally reaches this coroutine as a
            # task cancellation rather than KeyboardInterrupt. This is the
            # top-level coroutine, so the cancellation is absorbed here to end
            # the run cleanly instead of unwinding with a traceback.
            print("\nShutting down agent.")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so without this
        # clause Ctrl-C would end the program with a traceback. 130 is the
        # conventional exit status for termination by SIGINT.
        print("\nInterrupted; exiting.")
        sys.exit(130)
    except Exception as exc:
        # Last-resort boundary: report the failure and exit non-zero.
        print(f"Fatal error: {exc}")
        sys.exit(1)
