"""Load and validate the application's YAML configuration.

The public entry point is :func:`load_config`, which turns a YAML file into an
:class:`AppConfig` tree, raising :class:`ConfigError` for a missing file or the
first invalid value it encounters.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any

import yaml


class ConfigError(ValueError):
    """Raised when the YAML configuration is invalid."""


@dataclass(slots=True)
class ConnectivityCheck:
    """One entry of the ``connectivity_checks`` list: a URL and its expected response."""

    url: str
    expected_status: int  # load_config() requires 100-599
    expected_text: str | None = None
    require_final_url_match: bool = True
    allow_redirects: bool = False


@dataclass(slots=True)
class NetworkProfile:
    """One entry of the ``networks`` mapping."""

    key: str  # the entry's name in the ``networks`` mapping
    ssid: str
    profile_name: str  # falls back to ``ssid`` when not configured
    requires_portal_login: bool
    randomize_mac: bool
    restore_hardware_mac: bool = False
    auto_create_open_profile: bool = False
    # Presumably the interval between MAC refreshes; None when unset.
    # load_config() validates it (finite, > 0) only when randomize_mac is true.
    mac_refresh_hours: float | None = None


@dataclass(slots=True)
class MonitorConfig:
    """Timing settings from the ``monitor`` section, in whole seconds.

    load_config() treats every field of this class as a positive ``int``; keep
    that invariant if fields are added.
    """

    check_interval_seconds: int = 15
    connect_retry_cooldown_seconds: int = 20
    disconnect_grace_seconds: int = 45
    adapter_reset_cooldown_seconds: int = 120
    reconnect_wait_seconds: int = 12
    connection_timeout_seconds: int = 25


@dataclass(slots=True)
class SeleniumConfig:
    """Browser settings from the ``selenium`` section."""

    headless: bool = True
    page_load_timeout_seconds: int = 20  # load_config() requires > 0
    element_timeout_seconds: int = 12  # load_config() requires > 0
    max_login_retries: int = 3  # load_config() requires >= 0


@dataclass(slots=True)
class PortalConfig:
    """Portal-login settings from the ``portal`` section."""

    trigger_urls: list[str] = field(
        default_factory=lambda: [
            "http://captive.apple.com/",
            "http://www.msftconnecttest.com/redirect",
            "http://detectportal.firefox.com/",
        ],
    )
    accept_terms_name: str = "visitor_accept_terms"
    login_button_xpath: str = "//input[@type='submit' and @value='Log In']"


@dataclass(slots=True)
class LoggingConfig:
    """Settings from the ``logging`` section."""

    level: str = "INFO"  # load_config() upper-cases the configured value


@dataclass(slots=True)
class AppConfig:
    """Root configuration object produced by :func:`load_config`."""

    selected_network: str
    networks: dict[str, NetworkProfile]
    connectivity_checks: list[ConnectivityCheck]
    monitor: MonitorConfig = field(default_factory=MonitorConfig)
    selenium: SeleniumConfig = field(default_factory=SeleniumConfig)
    portal: PortalConfig = field(default_factory=PortalConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    wifi_interface_name: str | None = None
    allow_unsafe_mac_changes: bool = False
    randomize_mac_on_start: bool = True

    @property
    def active_network(self) -> NetworkProfile:
        """Return the profile named by ``selected_network``.

        Raises:
            ConfigError: If ``selected_network`` is not a key of ``networks``.
        """
        try:
            return self.networks[self.selected_network]
        except KeyError as exc:
            raise ConfigError(
                f"selected_network '{self.selected_network}' is not defined in networks"
            ) from exc


def load_config(path: str | Path) -> AppConfig:
    """Read, parse and validate the YAML configuration file at *path*.

    Args:
        path: Location of the YAML file.

    Returns:
        A validated :class:`AppConfig`.

    Raises:
        ConfigError: If the file does not exist (or is not a regular file) or
            any value fails validation.
        yaml.YAMLError: If the file is not valid YAML; propagated unchanged
            from ``yaml.safe_load``.
    """
    config_path = Path(path)
    # is_file() rather than exists(): a directory would otherwise pass here and
    # fail later inside read_text() with an unrelated OSError.
    if not config_path.is_file():
        raise ConfigError(f"Config file not found: {config_path}")

    # safe_load only constructs plain Python types, never arbitrary objects.
    raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    # An empty file parses to None and means "no settings". Other falsy
    # documents ([], 0, "", false) must still hit the mapping check below
    # rather than being silently coerced to {}.
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise ConfigError("The top-level YAML object must be a mapping")

    selected_network = _require_str(raw, "selected_network")
    networks = _parse_networks(raw.get("networks"))
    if selected_network not in networks:
        raise ConfigError(f"selected_network '{selected_network}' is not defined in networks")

    # Keyword arguments are evaluated left to right, so sections are validated
    # in the same order they are listed here.
    return AppConfig(
        selected_network=selected_network,
        networks=networks,
        connectivity_checks=_parse_connectivity_checks(raw.get("connectivity_checks")),
        monitor=_parse_monitor(_optional_mapping(raw, "monitor")),
        selenium=_parse_selenium(_optional_mapping(raw, "selenium")),
        portal=_parse_portal(_optional_mapping(raw, "portal")),
        logging=_parse_logging(_optional_mapping(raw, "logging")),
        wifi_interface_name=_optional_str(raw, "wifi_interface_name"),
        allow_unsafe_mac_changes=_require_bool(raw, "allow_unsafe_mac_changes", default=False),
        randomize_mac_on_start=_require_bool(raw, "randomize_mac_on_start", default=True),
    )


def _optional_mapping(data: dict[str, Any], key: str) -> dict[str, Any]:
    """Return the sub-mapping at *key*, or ``{}`` when it is absent or null.

    A YAML section whose children are all commented out (``monitor:``) parses
    to ``None``; it is treated like an omitted section so defaults apply.
    """
    value = data.get(key)
    if value is None:
        return {}
    if not isinstance(value, dict):
        raise ConfigError(f"{key} must be a mapping")
    return value


def _parse_networks(networks_raw: Any) -> dict[str, NetworkProfile]:
    """Validate the ``networks`` mapping and build one profile per entry."""
    if not isinstance(networks_raw, dict) or not networks_raw:
        raise ConfigError("networks must be a non-empty mapping")
    return {key: _parse_network(key, value) for key, value in networks_raw.items()}


def _parse_network(key: str, value: Any) -> NetworkProfile:
    """Build and validate the :class:`NetworkProfile` for ``networks.<key>``."""
    if not isinstance(value, dict):
        raise ConfigError(f"network '{key}' must be a mapping")
    context = f"networks.{key}"
    profile = NetworkProfile(
        key=key,
        ssid=_require_str(value, "ssid", context=context),
        # Falls back to the SSID when profile_name is missing or empty.
        profile_name=str(value.get("profile_name") or value.get("ssid") or "").strip(),
        requires_portal_login=_require_bool(value, "requires_portal_login", context=context),
        randomize_mac=_require_bool(value, "randomize_mac", context=context),
        restore_hardware_mac=_require_bool(
            value, "restore_hardware_mac", default=False, context=context
        ),
        auto_create_open_profile=_require_bool(
            value, "auto_create_open_profile", default=False, context=context
        ),
        mac_refresh_hours=_optional_float(value, "mac_refresh_hours", context=context),
    )
    # Reachable when profile_name is whitespace-only: it is truthy, so the SSID
    # fallback is skipped, and then strips to "".
    if not profile.profile_name:
        raise ConfigError(f"{context}.profile_name cannot be empty")
    hours = profile.mac_refresh_hours
    # isfinite() is needed because NaN compares False with everything, so a
    # plain "<= 0" test would let `.nan` through.
    if profile.randomize_mac and hours is not None and not (math.isfinite(hours) and hours > 0):
        raise ConfigError(f"{context}.mac_refresh_hours must be a finite number greater than 0")
    return profile


def _parse_connectivity_checks(checks_raw: Any) -> list[ConnectivityCheck]:
    """Validate the ``connectivity_checks`` list and build every entry."""
    if not isinstance(checks_raw, list) or not checks_raw:
        raise ConfigError("connectivity_checks must be a non-empty list")
    return [_parse_connectivity_check(index, item) for index, item in enumerate(checks_raw)]


def _parse_connectivity_check(index: int, item: Any) -> ConnectivityCheck:
    """Build and validate the connectivity check at position *index*."""
    context = f"connectivity_checks[{index}]"
    # Without this guard a scalar entry (e.g. a bare URL string) would crash
    # with AttributeError on item.get() instead of raising ConfigError.
    if not isinstance(item, dict):
        raise ConfigError(f"{context} must be a mapping")
    url = _require_str(item, "url", context=context)
    expected_status = _require_int(item, "expected_status", context=context)
    if not 100 <= expected_status <= 599:
        raise ConfigError(f"{context}.expected_status must be an HTTP status code (100-599)")
    return ConnectivityCheck(
        url=url,
        expected_status=expected_status,
        expected_text=_optional_str(item, "expected_text", context=context),
        require_final_url_match=_require_bool(
            item, "require_final_url_match", default=True, context=context
        ),
        allow_redirects=_require_bool(item, "allow_redirects", default=False, context=context),
    )


def _parse_monitor(monitor_raw: dict[str, Any]) -> MonitorConfig:
    """Build :class:`MonitorConfig`, requiring every field to be a positive int.

    Defaults come from the dataclass itself so they cannot drift from it.
    """
    defaults = MonitorConfig()
    values: dict[str, int] = {}
    for spec in fields(MonitorConfig):
        value = _require_int(
            monitor_raw, spec.name, default=getattr(defaults, spec.name), context="monitor"
        )
        if value <= 0:
            raise ConfigError(f"monitor.{spec.name} must be greater than 0")
        values[spec.name] = value
    return MonitorConfig(**values)


def _parse_selenium(selenium_raw: dict[str, Any]) -> SeleniumConfig:
    """Build :class:`SeleniumConfig`, rejecting non-positive timeouts and negative retries."""
    defaults = SeleniumConfig()
    selenium = SeleniumConfig(
        headless=_require_bool(
            selenium_raw, "headless", default=defaults.headless, context="selenium"
        ),
        page_load_timeout_seconds=_require_int(
            selenium_raw,
            "page_load_timeout_seconds",
            default=defaults.page_load_timeout_seconds,
            context="selenium",
        ),
        element_timeout_seconds=_require_int(
            selenium_raw,
            "element_timeout_seconds",
            default=defaults.element_timeout_seconds,
            context="selenium",
        ),
        max_login_retries=_require_int(
            selenium_raw,
            "max_login_retries",
            default=defaults.max_login_retries,
            context="selenium",
        ),
    )
    for name in ("page_load_timeout_seconds", "element_timeout_seconds"):
        if getattr(selenium, name) <= 0:
            raise ConfigError(f"selenium.{name} must be greater than 0")
    # Zero is allowed because its meaning is decided by the consumer; only
    # negative counts are rejected here.
    if selenium.max_login_retries < 0:
        raise ConfigError("selenium.max_login_retries must be 0 or greater")
    return selenium


def _parse_portal(portal_raw: dict[str, Any]) -> PortalConfig:
    """Build :class:`PortalConfig`, using the dataclass defaults for missing keys."""
    defaults = PortalConfig()
    return PortalConfig(
        trigger_urls=_require_string_list(
            portal_raw, "trigger_urls", default=defaults.trigger_urls, context="portal"
        ),
        accept_terms_name=_require_str(
            portal_raw, "accept_terms_name", default=defaults.accept_terms_name, context="portal"
        ),
        login_button_xpath=_require_str(
            portal_raw, "login_button_xpath", default=defaults.login_button_xpath, context="portal"
        ),
    )


def _parse_logging(logging_raw: dict[str, Any]) -> LoggingConfig:
    """Build :class:`LoggingConfig`; the level name is upper-cased."""
    level = _require_str(logging_raw, "level", default=LoggingConfig().level, context="logging")
    return LoggingConfig(level=level.upper())


def _require_str(
    data: dict[str, Any],
    key: str,
    default: str | None = None,
    *,
    context: str = "config",
) -> str:
    """Return ``data[key]`` (or *default*) stripped; it must be a non-empty string."""
    value = data.get(key, default)
    if not isinstance(value, str) or not value.strip():
        raise ConfigError(f"{context}.{key} must be a non-empty string")
    return value.strip()


def _optional_str(data: dict[str, Any], key: str, *, context: str = "config") -> str | None:
    """Return ``data[key]`` stripped, or ``None`` when it is missing, null or blank.

    Raises:
        ConfigError: If the value is set but is not a string.
    """
    value = data.get(key)
    if value is None:
        return None
    if not isinstance(value, str):
        raise ConfigError(f"{context}.{key} must be a string when set")
    return value.strip() or None


def _require_bool(
    data: dict[str, Any],
    key: str,
    default: bool | None = None,
    *,
    context: str = "config",
) -> bool:
    """Return ``data[key]`` (or *default*); it must be a real YAML boolean."""
    value = data.get(key, default)
    if not isinstance(value, bool):
        raise ConfigError(f"{context}.{key} must be a boolean")
    return value


def _require_int(
    data: dict[str, Any],
    key: str,
    default: int | None = None,
    *,
    context: str = "config",
) -> int:
    """Return ``data[key]`` (or *default*); it must be an int and not a bool."""
    value = data.get(key, default)
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ConfigError(f"{context}.{key} must be an integer")
    return value


def _optional_float(
    data: dict[str, Any],
    key: str,
    *,
    context: str = "config",
) -> float | None:
    """Return ``data[key]`` as a float, or ``None`` when missing or null.

    Ints are accepted and converted. Bools are rejected even though ``bool``
    subclasses ``int``.
    """
    value = data.get(key)
    if value is None:
        return None
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ConfigError(f"{context}.{key} must be a number when set")
    return float(value)


def _require_string_list(
    data: dict[str, Any],
    key: str,
    default: list[str] | None = None,
    *,
    context: str = "config",
) -> list[str]:
    """Return a new list of stripped, non-empty strings from ``data[key]`` (or *default*)."""
    value = data.get(key, default)
    if not isinstance(value, list) or not value:
        raise ConfigError(f"{context}.{key} must be a non-empty list")
    normalized: list[str] = []
    for item in value:
        if not isinstance(item, str) or not item.strip():
            raise ConfigError(f"{context}.{key} must contain non-empty strings")
        normalized.append(item.strip())
    return normalized