Files
OSU-Public-Wi-Fi-Login/osu_wifi_login/config.py

390 lines
12 KiB
Python

from __future__ import annotations

import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml
class ConfigError(ValueError):
    """Raised when the YAML configuration file is missing or contains invalid values.

    Subclasses ``ValueError``, so callers that catch ``ValueError`` also see it.
    """
@dataclass(slots=True)
class ConnectivityCheck:
    """One entry of the ``connectivity_checks`` list in the config.

    How these fields are applied is decided by the code that performs the
    check, which is not part of this module.
    """

    url: str
    expected_status: int
    expected_text: str | None = None  # None when unset or blank in the config
    require_final_url_match: bool = True
    allow_redirects: bool = False
@dataclass(slots=True)
class NetworkProfile:
    """One entry of the ``networks`` mapping in the config.

    ``load_config`` builds one instance per key under ``networks``.
    """

    key: str  # the entry's key under ``networks``
    ssid: str
    profile_name: str  # load_config falls back to ``ssid`` when this is omitted
    requires_portal_login: bool
    randomize_mac: bool
    restore_hardware_mac: bool = False
    auto_create_open_profile: bool = False
    # None when unset; load_config requires a value > 0 when randomize_mac is true.
    mac_refresh_hours: float | None = None
@dataclass(slots=True)
class MonitorConfig:
    """Timing settings read from the ``monitor`` section of the config.

    ``load_config`` rejects any value that is not greater than 0.
    """

    check_interval_seconds: int = 15
    connect_retry_cooldown_seconds: int = 20
    disconnect_grace_seconds: int = 45
    adapter_reset_cooldown_seconds: int = 120
    reconnect_wait_seconds: int = 12
    connection_timeout_seconds: int = 25
@dataclass(slots=True)
class SeleniumConfig:
    """Settings read from the ``selenium`` section of the config."""

    headless: bool = True
    page_load_timeout_seconds: int = 20
    element_timeout_seconds: int = 12
    max_login_retries: int = 3
@dataclass(slots=True)
class PortalConfig:
    """Settings read from the ``portal`` section of the config."""

    # default_factory gives every instance its own list, so mutating one
    # instance's trigger_urls cannot leak into another.
    trigger_urls: list[str] = field(
        default_factory=lambda: [
            "http://captive.apple.com/",
            "http://www.msftconnecttest.com/redirect",
            "http://detectportal.firefox.com/",
        ],
    )
    accept_terms_name: str = "visitor_accept_terms"
    login_button_xpath: str = "//input[@type='submit' and @value='Log In']"  # an XPath expression
@dataclass(slots=True)
class LoggingConfig:
    """Settings read from the ``logging`` section of the config."""

    level: str = "INFO"  # load_config stores the configured value upper-cased
@dataclass(slots=True)
class AppConfig:
    """Top-level application configuration.

    ``load_config`` builds and validates instances from the YAML file;
    constructing one directly performs no validation (there is no
    ``__post_init__``).
    """

    selected_network: str  # expected to be a key of ``networks`` (see active_network)
    networks: dict[str, NetworkProfile]  # keyed by the YAML key under ``networks``
    connectivity_checks: list[ConnectivityCheck]
    monitor: MonitorConfig = field(default_factory=MonitorConfig)
    selenium: SeleniumConfig = field(default_factory=SeleniumConfig)
    portal: PortalConfig = field(default_factory=PortalConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    wifi_interface_name: str | None = None  # None when not set in the config
    allow_unsafe_mac_changes: bool = False
    randomize_mac_on_start: bool = True

    @property
    def active_network(self) -> NetworkProfile:
        """Return the network profile named by ``selected_network``.

        Raises:
            ConfigError: if ``selected_network`` is not a key of ``networks``.
        """
        try:
            return self.networks[self.selected_network]
        except KeyError as exc:
            # Chain the KeyError so the original lookup failure stays visible.
            raise ConfigError(
                f"selected_network '{self.selected_network}' is not defined in networks"
            ) from exc
def load_config(path: str | Path) -> AppConfig:
    """Load and validate the YAML configuration file at *path*.

    The optional sections ``monitor``, ``selenium``, ``portal`` and
    ``logging`` may be omitted or left empty (e.g. ``monitor:`` with every
    child commented out). Any key they do not set falls back to the default
    declared on the matching dataclass.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The validated :class:`AppConfig`.

    Raises:
        ConfigError: if the file is missing or unreadable, is not valid YAML,
            or any value fails validation. I/O and YAML parse errors are
            re-raised as ``ConfigError`` so callers only need to catch one type.
    """
    raw = _read_yaml_mapping(Path(path))

    selected_network = _require_str(raw, "selected_network")
    networks = _parse_networks(raw.get("networks"))
    if selected_network not in networks:
        raise ConfigError(f"selected_network '{selected_network}' is not defined in networks")

    connectivity_checks = _parse_connectivity_checks(raw.get("connectivity_checks"))
    monitor = _parse_monitor(_optional_section(raw, "monitor"))
    selenium = _parse_selenium(_optional_section(raw, "selenium"))
    portal = _parse_portal(_optional_section(raw, "portal"))
    logging_config = LoggingConfig(
        level=_require_str(
            _optional_section(raw, "logging"),
            "level",
            default=LoggingConfig().level,
            context="logging",
        ).upper(),
    )

    return AppConfig(
        selected_network=selected_network,
        networks=networks,
        connectivity_checks=connectivity_checks,
        monitor=monitor,
        selenium=selenium,
        portal=portal,
        logging=logging_config,
        wifi_interface_name=_optional_str(raw, "wifi_interface_name"),
        allow_unsafe_mac_changes=_require_bool(raw, "allow_unsafe_mac_changes", default=False),
        randomize_mac_on_start=_require_bool(raw, "randomize_mac_on_start", default=True),
    )


def _read_yaml_mapping(config_path: Path) -> dict[str, Any]:
    """Read *config_path* and return its top-level YAML mapping (empty file -> {})."""
    if not config_path.exists():
        raise ConfigError(f"Config file not found: {config_path}")
    try:
        text = config_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. the path is a directory, is not readable, or is not UTF-8.
        raise ConfigError(f"Could not read config file {config_path}: {exc}") from exc
    try:
        raw = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        raise ConfigError(f"Config file {config_path} is not valid YAML: {exc}") from exc
    if raw is None:
        # An empty or comment-only file; let the missing-key checks report it.
        return {}
    if not isinstance(raw, dict):
        raise ConfigError("The top-level YAML object must be a mapping")
    return raw


def _optional_section(raw: dict[str, Any], name: str) -> dict[str, Any]:
    """Return the mapping under *name*, treating a missing or null section as empty."""
    section = raw.get(name)
    if section is None:
        # ``monitor:`` with all children commented out parses to None.
        return {}
    if not isinstance(section, dict):
        raise ConfigError(f"{name} must be a mapping")
    return section


def _parse_networks(networks_raw: Any) -> dict[str, NetworkProfile]:
    """Validate the ``networks`` mapping and build one NetworkProfile per entry."""
    if not isinstance(networks_raw, dict) or not networks_raw:
        raise ConfigError("networks must be a non-empty mapping")
    networks: dict[str, NetworkProfile] = {}
    for key, value in networks_raw.items():
        # YAML allows keys such as ``2024:`` that parse as int; selected_network
        # is always a string, so such an entry could never be selected.
        if not isinstance(key, str):
            raise ConfigError(f"networks keys must be strings, got {key!r}")
        networks[key] = _parse_network(key, value)
    return networks


def _parse_network(key: str, value: Any) -> NetworkProfile:
    """Validate one ``networks.<key>`` entry and build its NetworkProfile."""
    context = f"networks.{key}"
    if not isinstance(value, dict):
        raise ConfigError(f"network '{key}' must be a mapping")
    ssid = _require_str(value, "ssid", context=context)
    profile = NetworkProfile(
        key=key,
        ssid=ssid,
        # Falls back to the SSID; non-string values are coerced with str().
        profile_name=str(value.get("profile_name") or ssid).strip(),
        requires_portal_login=_require_bool(value, "requires_portal_login", context=context),
        randomize_mac=_require_bool(value, "randomize_mac", context=context),
        restore_hardware_mac=_require_bool(
            value, "restore_hardware_mac", default=False, context=context
        ),
        auto_create_open_profile=_require_bool(
            value, "auto_create_open_profile", default=False, context=context
        ),
        mac_refresh_hours=_optional_float(value, "mac_refresh_hours", context=context),
    )
    if not profile.profile_name:
        raise ConfigError(f"{context}.profile_name cannot be empty")
    refresh_hours = profile.mac_refresh_hours
    # Only enforced while MAC randomization is enabled for this network.
    if profile.randomize_mac and refresh_hours is not None and refresh_hours <= 0:
        raise ConfigError(f"{context}.mac_refresh_hours must be greater than 0")
    return profile


def _parse_connectivity_checks(checks_raw: Any) -> list[ConnectivityCheck]:
    """Validate the ``connectivity_checks`` list and build its entries."""
    if not isinstance(checks_raw, list) or not checks_raw:
        raise ConfigError("connectivity_checks must be a non-empty list")
    return [_parse_connectivity_check(index, item) for index, item in enumerate(checks_raw)]


def _parse_connectivity_check(index: int, item: Any) -> ConnectivityCheck:
    """Validate the connectivity check at position *index* of the list."""
    context = f"connectivity_checks[{index}]"
    # A bare string or number here would otherwise crash on ``item.get``.
    if not isinstance(item, dict):
        raise ConfigError(f"{context} must be a mapping")
    return ConnectivityCheck(
        url=_require_str(item, "url", context=context),
        expected_status=_require_int(item, "expected_status", context=context),
        expected_text=_optional_str(item, "expected_text"),
        require_final_url_match=_require_bool(
            item, "require_final_url_match", default=True, context=context
        ),
        allow_redirects=_require_bool(item, "allow_redirects", default=False, context=context),
    )


def _parse_monitor(section: dict[str, Any]) -> MonitorConfig:
    """Build MonitorConfig from the ``monitor`` section; every value must be > 0."""
    defaults = MonitorConfig()
    names = (
        "check_interval_seconds",
        "connect_retry_cooldown_seconds",
        "disconnect_grace_seconds",
        "adapter_reset_cooldown_seconds",
        "reconnect_wait_seconds",
        "connection_timeout_seconds",
    )
    values: dict[str, int] = {}
    for name in names:
        # Defaults come from the dataclass so they are declared in one place.
        value = _require_int(section, name, default=getattr(defaults, name), context="monitor")
        if value <= 0:
            raise ConfigError(f"monitor.{name} must be greater than 0")
        values[name] = value
    return MonitorConfig(**values)


def _parse_selenium(section: dict[str, Any]) -> SeleniumConfig:
    """Build SeleniumConfig from the ``selenium`` section, defaulting unset keys."""
    defaults = SeleniumConfig()
    selenium = SeleniumConfig(
        headless=_require_bool(section, "headless", default=defaults.headless, context="selenium"),
        page_load_timeout_seconds=_require_int(
            section,
            "page_load_timeout_seconds",
            default=defaults.page_load_timeout_seconds,
            context="selenium",
        ),
        element_timeout_seconds=_require_int(
            section,
            "element_timeout_seconds",
            default=defaults.element_timeout_seconds,
            context="selenium",
        ),
        max_login_retries=_require_int(
            section,
            "max_login_retries",
            default=defaults.max_login_retries,
            context="selenium",
        ),
    )
    # Mirror the positivity checks applied to the monitor section.
    if selenium.page_load_timeout_seconds <= 0:
        raise ConfigError("selenium.page_load_timeout_seconds must be greater than 0")
    if selenium.element_timeout_seconds <= 0:
        raise ConfigError("selenium.element_timeout_seconds must be greater than 0")
    if selenium.max_login_retries < 0:
        raise ConfigError("selenium.max_login_retries cannot be negative")
    return selenium


def _parse_portal(section: dict[str, Any]) -> PortalConfig:
    """Build PortalConfig from the ``portal`` section, defaulting unset keys."""
    defaults = PortalConfig()
    return PortalConfig(
        # _require_string_list returns a fresh list, so the default is never shared.
        trigger_urls=_require_string_list(
            section, "trigger_urls", default=defaults.trigger_urls, context="portal"
        ),
        accept_terms_name=_require_str(
            section, "accept_terms_name", default=defaults.accept_terms_name, context="portal"
        ),
        login_button_xpath=_require_str(
            section, "login_button_xpath", default=defaults.login_button_xpath, context="portal"
        ),
    )
def _require_str(
data: dict[str, Any],
key: str,
default: str | None = None,
*,
context: str = "config",
) -> str:
value = data.get(key, default)
if not isinstance(value, str) or not value.strip():
raise ConfigError(f"{context}.{key} must be a non-empty string")
return value.strip()
def _optional_str(data: dict[str, Any], key: str) -> str | None:
value = data.get(key)
if value is None:
return None
if not isinstance(value, str):
raise ConfigError(f"config.{key} must be a string when set")
stripped = value.strip()
return stripped or None
def _require_bool(
data: dict[str, Any],
key: str,
default: bool | None = None,
*,
context: str = "config",
) -> bool:
value = data.get(key, default)
if not isinstance(value, bool):
raise ConfigError(f"{context}.{key} must be a boolean")
return value
def _require_int(
data: dict[str, Any],
key: str,
default: int | None = None,
*,
context: str = "config",
) -> int:
value = data.get(key, default)
if isinstance(value, bool) or not isinstance(value, int):
raise ConfigError(f"{context}.{key} must be an integer")
return value
def _optional_float(
data: dict[str, Any],
key: str,
*,
context: str = "config",
) -> float | None:
value = data.get(key)
if value is None:
return None
if isinstance(value, bool):
raise ConfigError(f"{context}.{key} must be a number when set")
if isinstance(value, int):
return float(value)
if not isinstance(value, float):
raise ConfigError(f"{context}.{key} must be a number when set")
return value
def _require_string_list(
data: dict[str, Any],
key: str,
default: list[str] | None = None,
*,
context: str = "config",
) -> list[str]:
value = data.get(key, default)
if not isinstance(value, list) or not value:
raise ConfigError(f"{context}.{key} must be a non-empty list")
normalized: list[str] = []
for item in value:
if not isinstance(item, str) or not item.strip():
raise ConfigError(f"{context}.{key} must contain non-empty strings")
normalized.append(item.strip())
return normalized