from __future__ import annotations

import locale
import random
import re
import subprocess
import tempfile
import time
import winreg
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import requests

from .config import ConnectivityCheck


class WifiCommandError(RuntimeError):
    """Raised when a Windows networking command fails."""


@dataclass(slots=True)
class WifiStatus:
    interface_name: str | None
    description: str | None
    state: str | None
    ssid: str | None

    @property
    def is_connected(self) -> bool:
        state = (self.state or "").casefold()
        # "disconnected" contains "connected", so rule it out explicitly first.
        if "disconnected" in state:
            return False
        return "connected" in state or "已连接" in state or "已連線" in state


def test_connectivity(checks: Iterable[ConnectivityCheck], timeout_seconds: int = 3) -> bool:
    """Return True as soon as any check passes, or False if every check fails."""
    for check in checks:
        try:
            response = requests.get(
                check.url,
                timeout=timeout_seconds,
                allow_redirects=check.allow_redirects,
            )
        except requests.RequestException:
            continue
        if response.status_code != check.expected_status:
            continue
        # A redirect to a different URL (e.g. a captive portal login page) fails this check.
        if check.require_final_url_match and _normalize_url(response.url) != _normalize_url(check.url):
            continue
        if check.expected_text is not None and check.expected_text not in response.text:
            continue
        # A genuine "204 No Content" response has an empty body.
        if check.expected_status == 204 and response.content:
            continue
        return True
    return False


def get_wifi_status(preferred_interface_name: str | None = None) -> WifiStatus:
    """Parse ``netsh wlan show interfaces`` and return the status of one interface.

    The interface named ``preferred_interface_name`` is used if it exists;
    otherwise the first interface that reports a recognizable state is chosen.
    """
    output = run_command(["netsh", "wlan", "show", "interfaces"])
    # Each interface is printed as a block separated from the next by blank lines.
    blocks = [block.strip() for block in re.split(r"(?:\r?\n){2,}", output) if block.strip()]
    parsed_blocks = [_parse_key_value_block(block) for block in blocks]

    candidate = None
    if preferred_interface_name:
        preferred = preferred_interface_name.casefold()
        for block in parsed_blocks:
            if (block.get("name") or "").casefold() == preferred:
                candidate = block
                break

    if candidate is None:
        for block in parsed_blocks:
            state = (block.get("state") or "").lower()
            if "connected" in state or "disconnected" in state or "已" in state:
                candidate = block
                break

    if candidate is None:
        return WifiStatus(interface_name=None, description=None, state=None, ssid=None)

    return WifiStatus(
        interface_name=candidate.get("name"),
        description=candidate.get("description"),
        state=candidate.get("state"),
        ssid=candidate.get("ssid"),
    )


def restart_adapter(interface_name: str) -> None:
    """Disable and re-enable an interface (requires administrator privileges)."""
    run_command(["netsh", "interface", "set", "interface", interface_name, "disable"])
    time.sleep(2)
    run_command(["netsh", "interface", "set", "interface", interface_name, "enable"])


def disconnect_wifi(interface_name: str | None = None) -> None:
    command = ["netsh", "wlan", "disconnect"]
    if interface_name:
        command.append(f"interface={interface_name}")
    run_command(command)


def connect_wifi(
    profile_name: str,
    interface_name: str | None = None,
    *,
    ssid: str | None = None,
    auto_create_open_profile: bool = False,
) -> None:
    """Connect using a saved WLAN profile, optionally creating an open profile first."""
    command = ["netsh", "wlan", "connect", f"name={profile_name}"]
    if ssid:
        command.append(f"ssid={ssid}")
    if interface_name:
        command.append(f"interface={interface_name}")

    try:
        run_command(command)
    except WifiCommandError as exc:
        if not _is_missing_profile_error(str(exc), profile_name):
            raise
        if auto_create_open_profile:
            # Create an open (no-authentication) profile, then retry once.
            create_open_wifi_profile(
                profile_name=profile_name,
                ssid=ssid or profile_name,
                interface_name=interface_name,
            )
            run_command(command)
            return
        available_profiles = list_wifi_profiles(interface_name)
        profiles_display = ", ".join(available_profiles) if available_profiles else "(none)"
        target = f" on interface {interface_name}" if interface_name else ""
        raise WifiCommandError(
            f'Windows WLAN profile "{profile_name}" is missing{target}. '
            f"Available profiles: {profiles_display}",
        ) from exc


def list_wifi_profiles(interface_name: str | None = None) -> list[str]:
    """Return the profile names listed by ``netsh wlan show profiles``."""
    command = ["netsh", "wlan", "show", "profiles"]
    if interface_name:
        command.append(f"interface={interface_name}")
    output = run_command(command)

    profiles: list[str] = []
    for raw_line in output.splitlines():
        line = raw_line.strip()
        if not line or ":" not in line:
            continue
        key, value = line.split(":", 1)
        key_lower = key.casefold()
        # English, Simplified Chinese and Traditional Chinese labels for "profile".
        if "profile" not in key_lower and "配置文件" not in key_lower and "設定檔" not in key_lower:
            continue
        name = value.strip()
        # Skip headings that carry no value and netsh's "<None>" placeholder.
        if not name or name.casefold() == "<none>":
            continue
        profiles.append(name)
    return profiles


def create_open_wifi_profile(
    profile_name: str,
    ssid: str,
    interface_name: str | None = None,
    *,
    connection_mode: str = "auto",
) -> None:
    """Add an open (no-authentication) WLAN profile for the current user."""
    profile_xml = _build_open_profile_xml(profile_name, ssid, connection_mode=connection_mode)
    temp_path: Path | None = None
    try:
        # delete=False so netsh can read the file after the handle is closed.
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".xml",
            delete=False,
            encoding="utf-8",
        ) as handle:
            handle.write(profile_xml)
            temp_path = Path(handle.name)

        command = ["netsh", "wlan", "add", "profile", f"filename={temp_path}", "user=current"]
        if interface_name:
            command.append(f"interface={interface_name}")
        run_command(command)
    except WifiCommandError as exc:
        raise WifiCommandError(
            f'Failed to create an open Wi-Fi profile for SSID "{ssid}": {exc}',
        ) from exc
    finally:
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)


def wait_for_ssid(
    target_ssid: str,
    timeout_seconds: int,
    preferred_interface_name: str | None = None,
) -> WifiStatus:
    """Poll once per second until connected to ``target_ssid`` or the timeout expires.

    The last observed status is returned either way, so callers should check it.
    """
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        status = get_wifi_status(preferred_interface_name)
        if status.is_connected and status.ssid == target_ssid:
            return status
        time.sleep(1)
    return get_wifi_status(preferred_interface_name)


def ensure_hardware_mac(adapter_locator: str | None = None) -> bool:
    """Remove the ``NetworkAddress`` override so the adapter uses its hardware MAC.

    Returns False if no override was set. The change usually takes effect only
    after the adapter is restarted (see ``restart_adapter``).
    """
    adapter_name = adapter_locator or _require_adapter_name()
    reg_path = find_adapter_registry_key(adapter_name)
    if not reg_path:
        raise WifiCommandError(f"Could not find registry entry for adapter '{adapter_name}'")

    try:
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, reg_path, 0, winreg.KEY_SET_VALUE) as key:
            try:
                winreg.DeleteValue(key, "NetworkAddress")
            except FileNotFoundError:
                return False
    except PermissionError as exc:
        raise WifiCommandError("Administrator privileges are required to restore hardware MAC") from exc
    return True


def randomize_mac(adapter_locator: str | None = None) -> str:
    """Write a random MAC address to the adapter's ``NetworkAddress`` registry value.

    Returns the new address as ``XX:XX:XX:XX:XX:XX``. As with ``ensure_hardware_mac``,
    the adapter usually has to be restarted before the new address is used.
    """
    adapter_name = adapter_locator or _require_adapter_name()
    reg_path = find_adapter_registry_key(adapter_name)
    if not reg_path:
        raise WifiCommandError(f"Could not find registry entry for adapter '{adapter_name}'")

    mac_bytes = [random.randint(0x00, 0xFF) for _ in range(6)]
    # Clear bit 0 (unicast) and set bit 1 (locally administered) of the first octet.
    mac_bytes[0] = (mac_bytes[0] & 0xFC) | 0x02
    mac_compact = "".join(f"{part:02X}" for part in mac_bytes)

    try:
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, reg_path, 0, winreg.KEY_SET_VALUE) as key:
            winreg.SetValueEx(key, "NetworkAddress", 0, winreg.REG_SZ, mac_compact)
    except PermissionError as exc:
        raise WifiCommandError("Administrator privileges are required to randomize the MAC address") from exc
    return ":".join(f"{part:02X}" for part in mac_bytes)


def find_adapter_registry_key(adapter_name: str) -> str | None:
    """Return the HKLM path of the adapter's driver key, or None if nothing matches.

    An adapter whose connection name (e.g. "Wi-Fi") equals ``adapter_name`` is
    preferred. Otherwise the first driver key whose DriverDesc, NetCfgInstanceId
    or ComponentId contains ``adapter_name`` (case-insensitively) is returned.
    """
    base_path = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e972-e325-11ce-bfc1-08002be10318}"
    adapter_name_folded = adapter_name.casefold()
    substring_match: str | None = None

    try:
        base_key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, base_path)
    except OSError:
        return None

    with base_key:
        index = 0
        while True:
            try:
                subkey_name = winreg.EnumKey(base_key, index)
            except OSError:
                break  # no more subkeys
            index += 1
            full_path = f"{base_path}\\{subkey_name}"
            try:
                subkey = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, full_path)
            except OSError:
                continue  # e.g. the access-restricted "Properties" subkey
            with subkey:
                fields = {
                    field_name: _query_registry_string(subkey, field_name)
                    for field_name in ("DriverDesc", "NetCfgInstanceId", "ComponentId")
                }

            # The connection name lives under the Network key, keyed by NetCfgInstanceId.
            instance_id = fields["NetCfgInstanceId"]
            connection_name = _connection_name(instance_id) if instance_id else None
            if connection_name is not None and connection_name.casefold() == adapter_name_folded:
                return full_path
            if substring_match is None and any(
                adapter_name_folded in value.casefold() for value in fields.values() if value
            ):
                substring_match = full_path
    return substring_match


def _query_registry_string(key: winreg.HKEYType, value_name: str) -> str | None:
    try:
        value, _ = winreg.QueryValueEx(key, value_name)
    except OSError:
        return None
    return str(value)


def _connection_name(instance_id: str) -> str | None:
    """Look up the connection name (e.g. "Wi-Fi") for an adapter's NetCfgInstanceId."""
    network_base = r"SYSTEM\CurrentControlSet\Control\Network\{4d36e972-e325-11ce-bfc1-08002be10318}"
    path = f"{network_base}\\{instance_id}\\Connection"
    try:
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, path) as key:
            name, _ = winreg.QueryValueEx(key, "Name")
    except OSError:
        return None
    return str(name)


def run_command(command: list[str]) -> str:
    """Run ``command`` and return its decoded stdout; raise WifiCommandError on failure."""
    try:
        result = subprocess.run(command, capture_output=True, text=False, check=False)
    except OSError as exc:  # e.g. the executable is not on PATH
        raise WifiCommandError(f"Could not run {command[0]!r}: {exc}") from exc
    stdout = _decode_output(result.stdout)
    stderr = _decode_output(result.stderr)
    if result.returncode != 0:
        # netsh often reports errors on stdout, so fall back to it.
        message = stderr or stdout or f"Command failed: {' '.join(command)}"
        raise WifiCommandError(message.strip())
    return stdout


def _is_missing_profile_error(message: str, profile_name: str) -> bool:
    folded = message.casefold()
    profile_folded = profile_name.casefold()
    return (
        "there is no profile" in folded
        or "is not found on the system" in folded
        or (f'profile "{profile_folded}"' in folded and "not found" in folded)
    )


def _build_open_profile_xml(profile_name: str, ssid: str, *, connection_mode: str) -> str:
    """Build a WLAN profile document for an open network (no authentication, no encryption)."""
    escaped_name = _xml_escape(profile_name)
    escaped_ssid = _xml_escape(ssid)
    return f"""<?xml version="1.0"?>
<WLANProfile xmlns="http://www.microsoft.com/networking/WLAN/profile/v1">
    <name>{escaped_name}</name>
    <SSIDConfig>
        <SSID>
            <name>{escaped_ssid}</name>
        </SSID>
        <nonBroadcast>false</nonBroadcast>
    </SSIDConfig>
    <connectionType>ESS</connectionType>
    <connectionMode>{connection_mode}</connectionMode>
    <MSM>
        <security>
            <authEncryption>
                <authentication>open</authentication>
                <encryption>none</encryption>
                <useOneX>false</useOneX>
            </authEncryption>
        </security>
    </MSM>
</WLANProfile>
"""


def _xml_escape(value: str) -> str:
    # "&" must be replaced first so the entities added below are not re-escaped.
    return (
        value.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&apos;")
    )


def _normalize_url(url: str) -> str:
    return url.strip().rstrip("/").casefold()


def _require_adapter_name() -> str:
    status = get_wifi_status()
    if not status.interface_name:
        raise WifiCommandError("Could not find a Wi-Fi adapter from 'netsh wlan show interfaces'")
    return status.interface_name


def _decode_output(raw: bytes) -> str:
    """Decode console output, trying the locale encoding and common CJK code pages."""
    for encoding in _candidate_encodings():
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="replace")


def _candidate_encodings() -> list[str]:
    preferred = locale.getpreferredencoding(False)
    return [preferred, "utf-8", "gbk", "cp936", "big5"]


def _parse_key_value_block(block: str) -> dict[str, str]:
    """Turn ``Key : Value`` lines into a dict keyed by normalized English names."""
    mapping: dict[str, str] = {}
    for raw_line in block.splitlines():
        line = raw_line.strip()
        if not line or ":" not in line:
            continue
        # Split on the first colon only; values such as MAC addresses contain colons.
        key, value = line.split(":", 1)
        normalized = _normalize_key(key.strip())
        if normalized:
            mapping[normalized] = value.strip()
    return mapping


def _normalize_key(key: str) -> str | None:
    lookup = {
        "name": "name",
        "名称": "name",
        "description": "description",
        "描述": "description",
        "state": "state",
        "状态": "state",
        "ssid": "ssid",
    }
    normalized = key.strip().casefold()
    # "BSSID" must not be mistaken for "SSID".
    if normalized == "bssid":
        return None
    return lookup.get(key.strip(), lookup.get(normalized))
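

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not called by the code above).
# The masking step in ``randomize_mac`` clears bit 0 of the first octet, which
# marks the address as unicast, and sets bit 1, which marks it as locally
# administered, so a random address cannot equal a vendor-assigned
# (universally administered) one. The two helpers below isolate that
# arithmetic so it can be checked without touching the registry; their names
# are assumptions made for this example. A side effect of the mask is that the
# second hex digit of the address is always 2, 6, A or E.
# ---------------------------------------------------------------------------


def _example_force_local_unicast(first_octet: int) -> int:
    """Hypothetical helper: apply the same mask as ``randomize_mac`` to one octet."""
    if not 0 <= first_octet <= 0xFF:
        raise ValueError("octet must be in range 0..255")
    # 0xFC keeps bits 2-7 and clears bits 0-1; OR-ing 0x02 then sets bit 1.
    return (first_octet & 0xFC) | 0x02


def _example_is_local_unicast(mac: str) -> bool:
    """Hypothetical helper: check a MAC such as "02:1A:2B:3C:4D:5E" for that bit pattern."""
    compact = mac.replace(":", "").replace("-", "")
    if len(compact) != 12:
        raise ValueError(f"not a 48-bit MAC address: {mac!r}")
    first_octet = int(compact[:2], 16)
    # The two low bits must be 0b10: bit 0 clear (unicast), bit 1 set (local).
    return first_octet & 0x03 == 0x02


# Usage: _example_force_local_unicast(0xAB) returns 0xAA, and any address
# returned by ``randomize_mac`` should satisfy ``_example_is_local_unicast``.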