diff --git a/CommonClient.py b/CommonClient.py index bd7113cb6f..425de74b0f 100644 --- a/CommonClient.py +++ b/CommonClient.py @@ -9,7 +9,6 @@ import sys import typing import time import functools -import warnings import ModuleUpdate ModuleUpdate.update() @@ -22,10 +21,11 @@ if __name__ == "__main__": Utils.init_logging("TextClient", exception_logger="Client") from MultiServer import CommandProcessor, mark_raw -from NetUtils import (Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission, NetworkSlot, - RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus, SlotType) +from NetUtils import (Endpoint, decode, NetworkItem, JSONtoTextParser, ClientStatus, Permission, NetworkSlot, + RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus, + SlotType, NetworkPlayer, encode_to_bytes) from Utils import Version, stream_input, async_start -from worlds import network_data_package, AutoWorldRegister +from worlds import network_data_package import os import ssl @@ -502,10 +502,11 @@ class CommonContext: """ `msgs` JSON serializable """ if not self.server or not self.server.socket.open or self.server.socket.closed: return - await self.server.socket.send(encode(msgs)) + await self.server.socket.send(encode_to_bytes(msgs)) - def consume_players_package(self, package: typing.List[tuple]): - self.player_names = {slot: name for team, slot, name, orig_name in package if self.team == team} + def consume_players_package(self, package: typing.List[NetworkPlayer]): + self.player_names = {network_player.slot: network_player.name for network_player in package + if self.team == network_player.team} self.player_names[0] = "Archipelago" def event_invalid_slot(self): @@ -514,11 +515,12 @@ class CommonContext: def event_invalid_game(self): raise Exception('Invalid Game; please verify that you connected with the right game to the correct world.') - async def server_auth(self, password_requested: bool = False): + async def server_auth(self, password_requested: bool = False) -> typing.Optional[str]: if password_requested and not self.password: logger.info('Enter the password required to join this game:') self.password = await self.console_input() return self.password + return None async def get_username(self): if not self.auth: @@ -942,11 +944,10 @@ async def process_server_cmd(ctx: CommonContext, args: dict): logger.info('--------------------------------') logger.info('Room Information:') logger.info('--------------------------------') - version = args["version"] - ctx.server_version = Version(*version) + ctx.server_version = Version.from_network_dict(args["version"]) if "generator_version" in args: - ctx.generator_version = Version(*args["generator_version"]) + ctx.generator_version = Version.from_network_dict(args["generator_version"]) logger.info(f'Server protocol version: {ctx.server_version.as_simple_string()}, ' f'generator version: {ctx.generator_version.as_simple_string()}, ' f'tags: {", ".join(args["tags"])}') @@ -1016,9 +1017,9 @@ async def process_server_cmd(ctx: CommonContext, args: dict): ctx.slot = args["slot"] # int keys get lost in JSON transfer ctx.slot_info = {0: NetworkSlot("Archipelago", "Archipelago", SlotType.player)} - ctx.slot_info.update({int(pid): data for pid, data in args["slot_info"].items()}) + ctx.slot_info.update({int(pid): NetworkSlot.from_network_dict(data) for pid, data in args["slot_info"].items()}) ctx.hint_points = args.get("hint_points", 0) - ctx.consume_players_package(args["players"]) + ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]]) ctx.stored_data_notification_keys.add(f"_read_hints_{ctx.team}_{ctx.slot}") if ctx.game: game = ctx.game @@ -1067,17 +1068,17 @@ async def process_server_cmd(ctx: CommonContext, args: dict): await ctx.send_msgs(sync_msg) if start_index == len(ctx.items_received): for item in args['items']: - ctx.items_received.append(NetworkItem(*item)) + ctx.items_received.append(NetworkItem.from_network_dict(item)) ctx.watcher_event.set() elif cmd == 'LocationInfo': - for item in [NetworkItem(*item) for item in args['locations']]: + for item in [NetworkItem.from_network_dict(item) for item in args['locations']]: ctx.locations_info[item.location] = item ctx.watcher_event.set() elif cmd == "RoomUpdate": if "players" in args: - ctx.consume_players_package(args["players"]) + ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]]) if "hint_points" in args: ctx.hint_points = args['hint_points'] if "checked_locations" in args: diff --git a/Main.py b/Main.py index bc2787579f..ccdd2ab06c 100644 --- a/Main.py +++ b/Main.py @@ -342,7 +342,18 @@ def main(args, seed=None, baked_server_options: dict[str, object] | None = None) # TODO: change to `"version": version_tuple` after getting better serialization AutoWorld.call_all(multiworld, "modify_multidata", multidata) - for key in ("slot_data", "er_hint_data"): + base_types_keys = ["er_hint_data"] + + # starting with 0.7.0 pre-encode slot data, until then multiserver does it on load + if version_tuple < (0, 7, 0): + base_types_keys.append("slot_data") + else: + for slot, data in multidata["slot_data"].items(): + multidata[slot] = NetUtils.encode_to_bytes(data) + assert type(multidata[slot]) is bytes + multidata["minimum_versions"]["server"] = max((0, 7, 0), multidata["minimum_versions"]["server"]) + + for key in base_types_keys: multidata[key] = convert_to_base_types(multidata[key]) multidata = zlib.compress(restricted_dumps(multidata), 9) diff --git a/MultiServer.py b/MultiServer.py index 11a9e394c6..083d23c438 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -31,6 +31,7 @@ if typing.TYPE_CHECKING: from NetUtils import ServerConnection import colorama +import orjson import websockets from websockets.extensions.permessage_deflate import PerMessageDeflate try: @@ -41,9 +42,9 @@ except ImportError: import NetUtils import Utils -from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text +from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text, __version__ from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ - SlotType, LocationStore, MultiData, Hint, HintStatus + SlotType, LocationStore, MultiData, Hint, HintStatus, encode_to_bytes from BaseClasses import ItemClassification @@ -169,7 +170,7 @@ team_slot = typing.Tuple[int, int] class Context: - dumper = staticmethod(encode) + dumper = staticmethod(encode_to_bytes) loader = staticmethod(decode) simple_options = {"hint_cost": int, @@ -453,7 +454,7 @@ class Context: self.read_data["race_mode"] = lambda: decoded_obj.get("race_mode", 0) mdata_ver = decoded_obj["minimum_versions"]["server"] if mdata_ver > version_tuple: - raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver}," + raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver}, " f"however this server is of version {version_tuple}") self.generator_version = Version(*decoded_obj["version"]) clients_ver = decoded_obj["minimum_versions"].get("clients", {}) @@ -490,6 +491,10 @@ class Context: self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory self.slot_data = decoded_obj['slot_data'] for slot, data in self.slot_data.items(): + if not isinstance(data, bytes): + data = encode_to_bytes(data) + data = orjson.Fragment(data) + self.slot_data[slot] = data self.read_data[f"slot_data_{slot}"] = lambda data=data: data self.er_hint_data = {int(player): {int(address): name for address, name in loc_data.items()} for player, loc_data in decoded_obj["er_hint_data"].items()} @@ -1786,11 +1791,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): if cmd == 'Connect': if not args or 'password' not in args or type(args['password']) not in [str, type(None)] or \ - 'game' not in args: + 'game' not in args or "version" not in args: await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments", 'text': 'Connect', "original_cmd": cmd}]) return - + args["version"] = Version.from_network_dict(args["version"]) errors = set() if ctx.password and args['password'] != ctx.password: errors.add('InvalidPassword') @@ -1806,7 +1811,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): if not ignore_game and args['game'] != game: errors.add('InvalidGame') minver = min_client_version if ignore_game else ctx.minimum_client_versions[slot] - if minver > args['version']: + if minver > args["version"]: errors.add('IncompatibleVersion') try: client.items_handling = args['items_handling'] @@ -1814,7 +1819,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): errors.add('InvalidItemsHandling') # only exact version match allowed - if ctx.compatibility == 0 and args['version'] != version_tuple: + if ctx.compatibility == 0 and args['version'] != Version(__version__): errors.add('IncompatibleVersion') if errors: ctx.logger.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.") diff --git a/NetUtils.py b/NetUtils.py index 45279183f6..18794dffc8 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence import typing import enum import warnings -from json import JSONEncoder, JSONDecoder +import orjson if typing.TYPE_CHECKING: from websockets import WebSocketServerProtocol as ServerConnection @@ -78,6 +78,11 @@ class NetworkPlayer(typing.NamedTuple): alias: str name: str + @classmethod + def from_network_dict(cls, source: dict): + source.pop("class", None) + return cls(**source) + class NetworkSlot(typing.NamedTuple): """Represents a particular slot across teams.""" @@ -86,6 +91,11 @@ class NetworkSlot(typing.NamedTuple): type: SlotType group_members: Sequence[int] = () # only populated if type == group + @classmethod + def from_network_dict(cls, source: dict): + source.pop("class", None) + return cls(**source) + class NetworkItem(typing.NamedTuple): item: int @@ -94,6 +104,11 @@ class NetworkItem(typing.NamedTuple): """ Sending player, except in LocationInfo (from LocationScouts), where it is the receiving player. """ flags: int = 0 + @classmethod + def from_network_dict(cls, source: dict): + source.pop("class", None) + return cls(**source) + def _scan_for_TypedTuples(obj: typing.Any) -> typing.Any: if isinstance(obj, tuple) and hasattr(obj, "_fields"): # NamedTuple is not actually a parent class @@ -128,15 +143,12 @@ def convert_to_base_types(obj: typing.Any) -> _base_types: raise Exception(f"Cannot handle {type(obj)}") -_encode = JSONEncoder( - ensure_ascii=False, - check_circular=False, - separators=(',', ':'), -).encode +def encode_to_bytes(obj: typing.Any) -> bytes: + return orjson.dumps(_scan_for_TypedTuples(obj), option=orjson.OPT_NON_STR_KEYS) def encode(obj: typing.Any) -> str: - return _encode(_scan_for_TypedTuples(obj)) + return encode_to_bytes(obj).decode() def get_any_version(data: dict) -> Version: @@ -144,33 +156,10 @@ def get_any_version(data: dict) -> Version: return Version(int(data["major"]), int(data["minor"]), int(data["build"])) -allowlist = { - "NetworkPlayer": NetworkPlayer, - "NetworkItem": NetworkItem, - "NetworkSlot": NetworkSlot -} - -custom_hooks = { - "Version": get_any_version -} - - -def _object_hook(o: typing.Any) -> typing.Any: - if isinstance(o, dict): - hook = custom_hooks.get(o.get("class", None), None) - if hook: - return hook(o) - cls = allowlist.get(o.get("class", None), None) - if cls: - for key in tuple(o): - if key not in cls._fields: - del (o[key]) - return cls(**o) - - return o - - -decode = JSONDecoder(object_hook=_object_hook).decode +def decode(data: str | bytes) -> typing.Any: + if isinstance(data, str): + data = data.encode() + return orjson.loads(data) class Endpoint: diff --git a/Utils.py b/Utils.py index b7616b57b1..cd8409776e 100644 --- a/Utils.py +++ b/Utils.py @@ -46,6 +46,11 @@ class Version(typing.NamedTuple): def as_simple_string(self) -> str: return ".".join(str(item) for item in self) + @classmethod + def from_network_dict(cls, source: dict): + source.pop("class", None) + return cls(**source) + __version__ = "0.6.3" version_tuple = tuplize_version(__version__)