Files
OSU-Public-Wi-Fi-Login/osu_wifi_login/windows.py

413 lines
13 KiB
Python

from __future__ import annotations
import locale
import random
import re
import subprocess
import tempfile
import time
import winreg
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import requests
from .config import ConnectivityCheck
class WifiCommandError(RuntimeError):
    """Error type used by this module when a Windows networking command fails."""
@dataclass(slots=True)
class WifiStatus:
interface_name: str | None
description: str | None
state: str | None
ssid: str | None
@property
def is_connected(self) -> bool:
state = (self.state or "").lower()
return "connected" in state or "已连接" in state or "已連線" in state
def test_connectivity(checks: Iterable[ConnectivityCheck], timeout_seconds: int = 3) -> bool:
    """Report whether at least one connectivity probe succeeds.

    Each check is fetched with ``requests.get``; the first response that
    satisfies all of the check's expectations ends the scan.

    Args:
        checks: Probe definitions. Each one supplies ``url``,
            ``allow_redirects``, ``expected_status``,
            ``require_final_url_match`` and ``expected_text``.
        timeout_seconds: Timeout passed to ``requests.get`` for every probe.

    Returns:
        ``True`` as soon as one probe passes, ``False`` when none do.
    """
    for probe in checks:
        try:
            reply = requests.get(
                probe.url,
                timeout=timeout_seconds,
                allow_redirects=probe.allow_redirects,
            )
        except requests.RequestException:
            # A probe that cannot be fetched is skipped, not treated as fatal.
            continue

        rejected = (
            reply.status_code != probe.expected_status
            or (
                probe.require_final_url_match
                and _normalize_url(reply.url) != _normalize_url(probe.url)
            )
            or (probe.expected_text is not None and probe.expected_text not in reply.text)
            # A probe expecting 204 is rejected when the response carries a body.
            or (probe.expected_status == 204 and reply.content)
        )
        if not rejected:
            return True
    return False
def get_wifi_status(preferred_interface_name: str | None = None) -> WifiStatus:
    """Query ``netsh wlan show interfaces`` and return one interface's status.

    The command output is split into blank-line-separated blocks and each block
    is parsed into normalized key/value pairs.

    Args:
        preferred_interface_name: Interface name to look for first
            (case-insensitive). When it is not given or not found, the first
            block that looks like an interface is used instead.

    Returns:
        A ``WifiStatus`` for the selected interface, or one whose fields are
        all ``None`` when no block carried any recognised field.

    Raises:
        WifiCommandError: Propagated from ``run_command`` when netsh fails.
    """
    output = run_command(["netsh", "wlan", "show", "interfaces"])
    blocks = [block.strip() for block in re.split(r"(?:\r?\n){2,}", output) if block.strip()]
    parsed_blocks = [_parse_key_value_block(block) for block in blocks]

    candidate = None
    if preferred_interface_name:
        preferred = preferred_interface_name.casefold()
        candidate = next(
            (block for block in parsed_blocks if (block.get("name") or "").casefold() == preferred),
            None,
        )
    if candidate is None:
        # Prefer the first block that reports a state. Text such as the netsh
        # banner parses to an empty mapping and must not be picked.
        candidate = next((block for block in parsed_blocks if block.get("state")), None)
    if candidate is None:
        # Fall back to the first block with any recognised field at all.
        candidate = next((block for block in parsed_blocks if block), None)
    if candidate is None:
        return WifiStatus(interface_name=None, description=None, state=None, ssid=None)

    return WifiStatus(
        interface_name=candidate.get("name"),
        description=candidate.get("description"),
        state=candidate.get("state"),
        ssid=candidate.get("ssid"),
    )
def restart_adapter(interface_name: str) -> None:
    """Disable and then re-enable a network interface through ``netsh``.

    Args:
        interface_name: Interface name passed to
            ``netsh interface set interface``.

    Raises:
        WifiCommandError: Propagated from ``run_command`` if either step fails.
    """
    set_interface = ["netsh", "interface", "set", "interface", interface_name]
    run_command(set_interface + ["disable"])
    # Fixed pause between the two commands; presumably lets the adapter go
    # down before it is brought back up.
    time.sleep(2)
    run_command(set_interface + ["enable"])
def disconnect_wifi(interface_name: str | None = None) -> None:
    """Run ``netsh wlan disconnect``, optionally limited to one interface.

    Args:
        interface_name: When truthy, appended as ``interface=<name>``.

    Raises:
        WifiCommandError: Propagated from ``run_command`` when netsh fails.
    """
    interface_args = [f"interface={interface_name}"] if interface_name else []
    run_command(["netsh", "wlan", "disconnect", *interface_args])
def connect_wifi(
    profile_name: str,
    interface_name: str | None = None,
    *,
    ssid: str | None = None,
    auto_create_open_profile: bool = False,
) -> None:
    """Connect to a WLAN profile with ``netsh wlan connect``.

    Args:
        profile_name: Name of the WLAN profile to connect with.
        interface_name: Optional interface to pass as ``interface=<name>``.
        ssid: Optional SSID to pass as ``ssid=<ssid>``.
        auto_create_open_profile: When the profile is reported missing, create
            an open profile (using ``ssid`` or, failing that, ``profile_name``
            as the SSID) and retry the connection once.

    Raises:
        WifiCommandError: When netsh fails. If the failure is a missing profile
            and auto-creation is off, the error lists the profiles that exist.
    """
    optional_args: list[str] = []
    if ssid:
        optional_args.append(f"ssid={ssid}")
    if interface_name:
        optional_args.append(f"interface={interface_name}")
    command = ["netsh", "wlan", "connect", f"name={profile_name}", *optional_args]

    try:
        run_command(command)
    except WifiCommandError as exc:
        if not _is_missing_profile_error(str(exc), profile_name):
            # Any other failure is passed through unchanged.
            raise

        if auto_create_open_profile:
            create_open_wifi_profile(
                profile_name=profile_name,
                ssid=ssid or profile_name,
                interface_name=interface_name,
            )
            run_command(command)
            return

        known_profiles = list_wifi_profiles(interface_name)
        listing = ", ".join(known_profiles) if known_profiles else "<none>"
        where = interface_name or "<default>"
        raise WifiCommandError(
            f'Windows WLAN profile "{profile_name}" is missing on interface {where}. '
            f"Available profiles: {listing}",
        ) from exc
def list_wifi_profiles(interface_name: str | None = None) -> list[str]:
    """Return the profile names printed by ``netsh wlan show profiles``.

    A line counts as a profile entry when the text before its first colon
    contains "profile" (case-insensitive) or one of the two Chinese labels
    below, and the text after the colon is non-empty and not ``<None>``.

    Args:
        interface_name: When truthy, appended as ``interface=<name>``.

    Returns:
        Profile names in the order they appear in the output.

    Raises:
        WifiCommandError: Propagated from ``run_command`` when netsh fails.
    """
    command = ["netsh", "wlan", "show", "profiles"]
    if interface_name:
        command.append(f"interface={interface_name}")

    label_markers = ("profile", "配置文件", "設定檔")
    names: list[str] = []
    for raw_line in run_command(command).splitlines():
        label, separator, remainder = raw_line.strip().partition(":")
        if not separator:
            # Blank line or a line without a "key: value" shape.
            continue
        folded_label = label.casefold()
        if not any(marker in folded_label for marker in label_markers):
            continue
        profile = remainder.strip()
        if profile and profile != "<None>":
            names.append(profile)
    return names
def create_open_wifi_profile(
    profile_name: str,
    ssid: str,
    interface_name: str | None = None,
    *,
    connection_mode: str = "auto",
) -> None:
    """Register an open (no authentication, no encryption) WLAN profile.

    The profile XML is written to a temporary ``.xml`` file, imported with
    ``netsh wlan add profile ... user=current``, and the file is removed
    afterwards whether or not the import succeeded.

    Args:
        profile_name: Name stored in the profile.
        ssid: SSID the profile applies to.
        interface_name: When truthy, appended as ``interface=<name>``.
        connection_mode: Value written to the ``<connectionMode>`` element.

    Raises:
        WifiCommandError: When netsh rejects the profile; the original error is
            chained as the cause.
    """
    profile_xml = _build_open_profile_xml(profile_name, ssid, connection_mode=connection_mode)
    temp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".xml",
            delete=False,
            encoding="utf-8",
        ) as handle:
            # Record the path before writing: the file is created with
            # delete=False, so a failed write would otherwise leave it behind.
            temp_path = Path(handle.name)
            handle.write(profile_xml)
        # The file is closed here, so netsh reads the fully flushed content.
        command = ["netsh", "wlan", "add", "profile", f"filename={temp_path}", "user=current"]
        if interface_name:
            command.append(f"interface={interface_name}")
        run_command(command)
    except WifiCommandError as exc:
        raise WifiCommandError(
            f'Failed to create an open Wi-Fi profile for SSID "{ssid}": {exc}',
        ) from exc
    finally:
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
def wait_for_ssid(
    target_ssid: str,
    timeout_seconds: int,
    preferred_interface_name: str | None = None,
) -> WifiStatus:
    """Poll the Wi-Fi status until it is connected to ``target_ssid``.

    The status is sampled about once per second until the deadline passes.

    Args:
        target_ssid: SSID that must match the reported one exactly.
        timeout_seconds: Maximum time to keep polling.
        preferred_interface_name: Forwarded to ``get_wifi_status``.

    Returns:
        The first status that is connected to the target SSID, or a final
        fresh status taken after the deadline (which may not match).
    """
    give_up_at = time.monotonic() + timeout_seconds
    while True:
        if time.monotonic() >= give_up_at:
            # Timed out: hand back one last reading for the caller to inspect.
            return get_wifi_status(preferred_interface_name)
        snapshot = get_wifi_status(preferred_interface_name)
        if snapshot.is_connected and snapshot.ssid == target_ssid:
            return snapshot
        time.sleep(1)
def ensure_hardware_mac(adapter_locator: str | None = None) -> bool:
    """Delete the ``NetworkAddress`` registry value of an adapter.

    Args:
        adapter_locator: Text used to find the adapter's registry key. When
            omitted, the interface name detected via netsh is used.

    Returns:
        ``True`` when the value existed and was deleted, ``False`` when there
        was no ``NetworkAddress`` value to delete.

    Raises:
        WifiCommandError: When the adapter's registry key cannot be found or
            the registry access is denied.
    """
    adapter_name = adapter_locator or _require_adapter_name()
    reg_path = find_adapter_registry_key(adapter_name)
    if not reg_path:
        raise WifiCommandError(f"Could not find registry entry for adapter '{adapter_name}'")

    try:
        # The handle closes itself when the ``with`` block is left.
        with winreg.OpenKey(
            winreg.HKEY_LOCAL_MACHINE,
            reg_path,
            0,
            winreg.KEY_SET_VALUE,
        ) as adapter_key:
            try:
                winreg.DeleteValue(adapter_key, "NetworkAddress")
            except FileNotFoundError:
                # No override value is present.
                return False
    except PermissionError as exc:
        raise WifiCommandError("Administrator privileges are required to restore hardware MAC") from exc
    return True
def randomize_mac(adapter_locator: str | None = None) -> str:
    """Write a random MAC address to an adapter's ``NetworkAddress`` value.

    Six random bytes are generated; the first byte has its two lowest bits
    forced to ``10`` (locally administered, unicast). The address is stored in
    the registry as twelve uppercase hex digits without separators.

    Args:
        adapter_locator: Text used to find the adapter's registry key. When
            omitted, the interface name detected via netsh is used.

    Returns:
        The generated address formatted as ``AA:BB:CC:DD:EE:FF``.

    Raises:
        WifiCommandError: When the adapter's registry key cannot be found or
            the registry access is denied.
    """
    adapter_name = adapter_locator or _require_adapter_name()
    reg_path = find_adapter_registry_key(adapter_name)
    if not reg_path:
        raise WifiCommandError(f"Could not find registry entry for adapter '{adapter_name}'")

    mac_bytes = [random.randint(0x00, 0xFF) for _ in range(6)]
    # Clear the multicast bit (0x01) and set the locally-administered bit (0x02).
    mac_bytes[0] = (mac_bytes[0] & 0xFC) | 0x02
    mac_compact = "".join(f"{part:02X}" for part in mac_bytes)

    try:
        # ``with`` closes the handle even when SetValueEx raises.
        with winreg.OpenKey(
            winreg.HKEY_LOCAL_MACHINE,
            reg_path,
            0,
            winreg.KEY_SET_VALUE,
        ) as key:
            winreg.SetValueEx(key, "NetworkAddress", 0, winreg.REG_SZ, mac_compact)
    except PermissionError as exc:
        raise WifiCommandError("Administrator privileges are required to randomize the MAC address") from exc
    return ":".join(f"{part:02X}" for part in mac_bytes)
def find_adapter_registry_key(adapter_name: str) -> str | None:
    """Locate the registry subkey of a network adapter.

    Every subkey of the class key below is visited in enumeration order. A
    subkey matches when ``adapter_name`` occurs (case-insensitively, as a
    substring) in its ``DriverDesc``, ``NetCfgInstanceId`` or ``ComponentId``
    value, or in the ``Name`` value of its ``Connection`` child key.

    Args:
        adapter_name: Text to search for.

    Returns:
        The registry path (relative to ``HKEY_LOCAL_MACHINE``) of the first
        match, or ``None`` when the class key cannot be opened or no subkey
        matches.
    """
    hive = winreg.HKEY_LOCAL_MACHINE
    base_path = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e972-e325-11ce-bfc1-08002be10318}"
    try:
        class_key = winreg.OpenKey(hive, base_path)
    except OSError:
        return None

    # Each ``with`` closes its handle on every exit path (return, continue, error).
    with class_key:
        position = 0
        needle = adapter_name.casefold()
        while True:
            try:
                child_name = winreg.EnumKey(class_key, position)
            except OSError:
                # EnumKey failed at this index: stop enumerating without a match.
                return None
            position += 1

            candidate_path = f"{base_path}\\{child_name}"
            try:
                child_key = winreg.OpenKey(hive, candidate_path)
            except OSError:
                # Subkeys that cannot be opened are skipped.
                continue

            with child_key:
                for field_name in ("DriverDesc", "NetCfgInstanceId", "ComponentId"):
                    try:
                        field_value, _ = winreg.QueryValueEx(child_key, field_name)
                    except OSError:
                        continue
                    if needle in str(field_value).casefold():
                        return candidate_path

                try:
                    connection_key = winreg.OpenKey(child_key, "Connection")
                except OSError:
                    continue

                with connection_key:
                    try:
                        connection_name, _ = winreg.QueryValueEx(connection_key, "Name")
                    except OSError:
                        continue
                    if needle in str(connection_name).casefold():
                        return candidate_path
def run_command(command: list[str]) -> str:
    """Run a command and return its decoded standard output.

    Output is captured as bytes and decoded with ``_decode_output``.

    Args:
        command: Program and arguments, passed to ``subprocess.run`` as a list.

    Returns:
        The decoded stdout of the command.

    Raises:
        WifiCommandError: When the exit code is non-zero. The message is the
            stripped stderr, else stdout, else a generic "Command failed" text.
    """
    completed = subprocess.run(command, capture_output=True, text=False, check=False)
    out_text = _decode_output(completed.stdout)
    err_text = _decode_output(completed.stderr)
    if completed.returncode == 0:
        return out_text
    detail = err_text or out_text or f"Command failed: {' '.join(command)}"
    raise WifiCommandError(detail.strip())
def _is_missing_profile_error(message: str, profile_name: str) -> bool:
folded = message.casefold()
profile_folded = profile_name.casefold()
return (
"there is no profile" in folded
or "is not found on the system" in folded
or f'profile "{profile_folded}"' in folded and "not found" in folded
)
def _build_open_profile_xml(profile_name: str, ssid: str, *, connection_mode: str) -> str:
    """Render the WLAN profile XML for an open (unsecured) network.

    All three interpolated values are XML-escaped so that markup characters in
    any of them cannot break the document.

    Args:
        profile_name: Text for the profile's ``<name>`` element.
        ssid: Text for the ``<SSID><name>`` element.
        connection_mode: Text for the ``<connectionMode>`` element.

    Returns:
        The complete profile document as a string.
    """
    escaped_name = _xml_escape(profile_name)
    escaped_ssid = _xml_escape(ssid)
    # Escaped like the other two values; it was previously inserted verbatim.
    escaped_mode = _xml_escape(connection_mode)
    return f"""<?xml version="1.0"?>
<WLANProfile xmlns="http://www.microsoft.com/networking/WLAN/profile/v1">
<name>{escaped_name}</name>
<SSIDConfig>
<SSID>
<name>{escaped_ssid}</name>
</SSID>
<nonBroadcast>false</nonBroadcast>
</SSIDConfig>
<connectionType>ESS</connectionType>
<connectionMode>{escaped_mode}</connectionMode>
<MSM>
<security>
<authEncryption>
<authentication>open</authentication>
<encryption>none</encryption>
<useOneX>false</useOneX>
</authEncryption>
</security>
</MSM>
</WLANProfile>
"""
def _xml_escape(value: str) -> str:
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
def _normalize_url(url: str) -> str:
return url.strip().rstrip("/").casefold()
def _require_adapter_name() -> str:
    """Return the interface name reported by ``get_wifi_status``.

    Raises:
        WifiCommandError: When the status carries no interface name.
    """
    detected_name = get_wifi_status().interface_name
    if detected_name:
        return detected_name
    raise WifiCommandError("Could not find a Wi-Fi adapter from 'netsh wlan show interfaces'")
def _decode_output(raw: bytes) -> str:
    """Decode command output with the first candidate encoding that fits.

    Encodings from ``_candidate_encodings`` are tried in order. If none of them
    decodes the bytes cleanly, UTF-8 with replacement characters is used.
    """
    for codec in _candidate_encodings():
        try:
            text = raw.decode(codec)
        except UnicodeDecodeError:
            continue
        else:
            return text
    return raw.decode("utf-8", errors="replace")
def _candidate_encodings() -> list[str]:
preferred = locale.getpreferredencoding(False)
return [preferred, "utf-8", "gbk", "cp936", "big5"]
def _parse_key_value_block(block: str) -> dict[str, str]:
    """Parse ``key : value`` lines of one netsh block into a dictionary.

    Lines without a colon are ignored, as are lines whose key is not
    recognised by ``_normalize_key``. Each line is split at its first colon
    only. When a key occurs more than once the last occurrence wins.
    """
    fields: dict[str, str] = {}
    for raw_line in block.splitlines():
        label, separator, remainder = raw_line.strip().partition(":")
        if not separator:
            continue
        canonical = _normalize_key(label.strip())
        if canonical:
            fields[canonical] = remainder.strip()
    return fields
def _normalize_key(key: str) -> str | None:
lookup = {
"name": "name",
"名称": "name",
"description": "description",
"描述": "description",
"state": "state",
"状态": "state",
"ssid": "ssid",
}
normalized = key.strip().casefold()
if normalized == "bssid":
return None
return lookup.get(key.strip(), lookup.get(normalized))