Compare commits

...

2 Commits

Author SHA1 Message Date
Fabian Dill
43100f2c43 Test: add a failing test for big int 2025-08-03 10:54:17 +02:00
Fabian Dill
c6df02a355 Core: move MultiServer.py and CommonClient.py to orjson 2025-08-03 08:20:21 +02:00
6 changed files with 81 additions and 59 deletions

View File

@@ -9,7 +9,6 @@ import sys
import typing
import time
import functools
import warnings
import ModuleUpdate
ModuleUpdate.update()
@@ -22,10 +21,11 @@ if __name__ == "__main__":
Utils.init_logging("TextClient", exception_logger="Client")
from MultiServer import CommandProcessor, mark_raw
from NetUtils import (Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission, NetworkSlot,
RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus, SlotType)
from NetUtils import (Endpoint, decode, NetworkItem, JSONtoTextParser, ClientStatus, Permission, NetworkSlot,
RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus,
SlotType, NetworkPlayer, encode_to_bytes)
from Utils import Version, stream_input, async_start
from worlds import network_data_package, AutoWorldRegister
from worlds import network_data_package
import os
import ssl
@@ -502,10 +502,11 @@ class CommonContext:
""" `msgs` JSON serializable """
if not self.server or not self.server.socket.open or self.server.socket.closed:
return
await self.server.socket.send(encode(msgs))
await self.server.socket.send(encode_to_bytes(msgs))
def consume_players_package(self, package: typing.List[tuple]):
self.player_names = {slot: name for team, slot, name, orig_name in package if self.team == team}
def consume_players_package(self, package: typing.List[NetworkPlayer]):
self.player_names = {network_player.slot: network_player.name for network_player in package
if self.team == network_player.team}
self.player_names[0] = "Archipelago"
def event_invalid_slot(self):
@@ -514,11 +515,12 @@ class CommonContext:
def event_invalid_game(self):
raise Exception('Invalid Game; please verify that you connected with the right game to the correct world.')
async def server_auth(self, password_requested: bool = False):
async def server_auth(self, password_requested: bool = False) -> typing.Optional[str]:
if password_requested and not self.password:
logger.info('Enter the password required to join this game:')
self.password = await self.console_input()
return self.password
return None
async def get_username(self):
if not self.auth:
@@ -942,11 +944,10 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
logger.info('--------------------------------')
logger.info('Room Information:')
logger.info('--------------------------------')
version = args["version"]
ctx.server_version = Version(*version)
ctx.server_version = Version.from_network_dict(args["version"])
if "generator_version" in args:
ctx.generator_version = Version(*args["generator_version"])
ctx.generator_version = Version.from_network_dict(args["generator_version"])
logger.info(f'Server protocol version: {ctx.server_version.as_simple_string()}, '
f'generator version: {ctx.generator_version.as_simple_string()}, '
f'tags: {", ".join(args["tags"])}')
@@ -1016,9 +1017,9 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
ctx.slot = args["slot"]
# int keys get lost in JSON transfer
ctx.slot_info = {0: NetworkSlot("Archipelago", "Archipelago", SlotType.player)}
ctx.slot_info.update({int(pid): data for pid, data in args["slot_info"].items()})
ctx.slot_info.update({int(pid): NetworkSlot.from_network_dict(data) for pid, data in args["slot_info"].items()})
ctx.hint_points = args.get("hint_points", 0)
ctx.consume_players_package(args["players"])
ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]])
ctx.stored_data_notification_keys.add(f"_read_hints_{ctx.team}_{ctx.slot}")
if ctx.game:
game = ctx.game
@@ -1067,17 +1068,17 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
await ctx.send_msgs(sync_msg)
if start_index == len(ctx.items_received):
for item in args['items']:
ctx.items_received.append(NetworkItem(*item))
ctx.items_received.append(NetworkItem.from_network_dict(item))
ctx.watcher_event.set()
elif cmd == 'LocationInfo':
for item in [NetworkItem(*item) for item in args['locations']]:
for item in [NetworkItem.from_network_dict(item) for item in args['locations']]:
ctx.locations_info[item.location] = item
ctx.watcher_event.set()
elif cmd == "RoomUpdate":
if "players" in args:
ctx.consume_players_package(args["players"])
ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]])
if "hint_points" in args:
ctx.hint_points = args['hint_points']
if "checked_locations" in args:

13
Main.py
View File

@@ -342,7 +342,18 @@ def main(args, seed=None, baked_server_options: dict[str, object] | None = None)
# TODO: change to `"version": version_tuple` after getting better serialization
AutoWorld.call_all(multiworld, "modify_multidata", multidata)
for key in ("slot_data", "er_hint_data"):
base_types_keys = ["er_hint_data"]
# starting with 0.7.0 pre-encode slot data, until then multiserver does it on load
if version_tuple < (0, 7, 0):
base_types_keys.append("slot_data")
else:
for slot, data in multidata["slot_data"].items():
multidata[slot] = NetUtils.encode_to_bytes(data)
assert type(multidata[slot]) is bytes
multidata["minimum_versions"]["server"] = max((0, 7, 0), multidata["minimum_versions"]["server"])
for key in base_types_keys:
multidata[key] = convert_to_base_types(multidata[key])
multidata = zlib.compress(restricted_dumps(multidata), 9)

View File

@@ -31,6 +31,7 @@ if typing.TYPE_CHECKING:
from NetUtils import ServerConnection
import colorama
import orjson
import websockets
from websockets.extensions.permessage_deflate import PerMessageDeflate
try:
@@ -41,9 +42,9 @@ except ImportError:
import NetUtils
import Utils
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text, __version__
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType, LocationStore, MultiData, Hint, HintStatus
SlotType, LocationStore, MultiData, Hint, HintStatus, encode_to_bytes
from BaseClasses import ItemClassification
@@ -169,7 +170,7 @@ team_slot = typing.Tuple[int, int]
class Context:
dumper = staticmethod(encode)
dumper = staticmethod(encode_to_bytes)
loader = staticmethod(decode)
simple_options = {"hint_cost": int,
@@ -453,7 +454,7 @@ class Context:
self.read_data["race_mode"] = lambda: decoded_obj.get("race_mode", 0)
mdata_ver = decoded_obj["minimum_versions"]["server"]
if mdata_ver > version_tuple:
raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver},"
raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver}, "
f"however this server is of version {version_tuple}")
self.generator_version = Version(*decoded_obj["version"])
clients_ver = decoded_obj["minimum_versions"].get("clients", {})
@@ -490,6 +491,10 @@ class Context:
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
self.slot_data = decoded_obj['slot_data']
for slot, data in self.slot_data.items():
if not isinstance(data, bytes):
data = encode_to_bytes(data)
data = orjson.Fragment(data)
self.slot_data[slot] = data
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
self.er_hint_data = {int(player): {int(address): name for address, name in loc_data.items()}
for player, loc_data in decoded_obj["er_hint_data"].items()}
@@ -1786,11 +1791,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
if cmd == 'Connect':
if not args or 'password' not in args or type(args['password']) not in [str, type(None)] or \
'game' not in args:
'game' not in args or "version" not in args:
await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments", 'text': 'Connect',
"original_cmd": cmd}])
return
args["version"] = Version.from_network_dict(args["version"])
errors = set()
if ctx.password and args['password'] != ctx.password:
errors.add('InvalidPassword')
@@ -1806,7 +1811,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
if not ignore_game and args['game'] != game:
errors.add('InvalidGame')
minver = min_client_version if ignore_game else ctx.minimum_client_versions[slot]
if minver > args['version']:
if minver > args["version"]:
errors.add('IncompatibleVersion')
try:
client.items_handling = args['items_handling']
@@ -1814,7 +1819,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
errors.add('InvalidItemsHandling')
# only exact version match allowed
if ctx.compatibility == 0 and args['version'] != version_tuple:
if ctx.compatibility == 0 and args['version'] != Version(__version__):
errors.add('IncompatibleVersion')
if errors:
ctx.logger.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.")

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
import typing
import enum
import warnings
from json import JSONEncoder, JSONDecoder
import orjson
if typing.TYPE_CHECKING:
from websockets import WebSocketServerProtocol as ServerConnection
@@ -78,6 +78,11 @@ class NetworkPlayer(typing.NamedTuple):
alias: str
name: str
@classmethod
def from_network_dict(cls, source: dict):
source.pop("class", None)
return cls(**source)
class NetworkSlot(typing.NamedTuple):
"""Represents a particular slot across teams."""
@@ -86,6 +91,11 @@ class NetworkSlot(typing.NamedTuple):
type: SlotType
group_members: Sequence[int] = () # only populated if type == group
@classmethod
def from_network_dict(cls, source: dict):
source.pop("class", None)
return cls(**source)
class NetworkItem(typing.NamedTuple):
item: int
@@ -94,6 +104,11 @@ class NetworkItem(typing.NamedTuple):
""" Sending player, except in LocationInfo (from LocationScouts), where it is the receiving player. """
flags: int = 0
@classmethod
def from_network_dict(cls, source: dict):
source.pop("class", None)
return cls(**source)
def _scan_for_TypedTuples(obj: typing.Any) -> typing.Any:
if isinstance(obj, tuple) and hasattr(obj, "_fields"): # NamedTuple is not actually a parent class
@@ -128,15 +143,12 @@ def convert_to_base_types(obj: typing.Any) -> _base_types:
raise Exception(f"Cannot handle {type(obj)}")
_encode = JSONEncoder(
ensure_ascii=False,
check_circular=False,
separators=(',', ':'),
).encode
def encode_to_bytes(obj: typing.Any) -> bytes:
return orjson.dumps(_scan_for_TypedTuples(obj), option=orjson.OPT_NON_STR_KEYS)
def encode(obj: typing.Any) -> str:
return _encode(_scan_for_TypedTuples(obj))
return encode_to_bytes(obj).decode()
def get_any_version(data: dict) -> Version:
@@ -144,33 +156,10 @@ def get_any_version(data: dict) -> Version:
return Version(int(data["major"]), int(data["minor"]), int(data["build"]))
allowlist = {
"NetworkPlayer": NetworkPlayer,
"NetworkItem": NetworkItem,
"NetworkSlot": NetworkSlot
}
custom_hooks = {
"Version": get_any_version
}
def _object_hook(o: typing.Any) -> typing.Any:
if isinstance(o, dict):
hook = custom_hooks.get(o.get("class", None), None)
if hook:
return hook(o)
cls = allowlist.get(o.get("class", None), None)
if cls:
for key in tuple(o):
if key not in cls._fields:
del (o[key])
return cls(**o)
return o
decode = JSONDecoder(object_hook=_object_hook).decode
def decode(data: str | bytes) -> typing.Any:
if isinstance(data, str):
data = data.encode()
return orjson.loads(data)
class Endpoint:

View File

@@ -46,6 +46,11 @@ class Version(typing.NamedTuple):
def as_simple_string(self) -> str:
return ".".join(str(item) for item in self)
@classmethod
def from_network_dict(cls, source: dict):
source.pop("class", None)
return cls(**source)
__version__ = "0.6.3"
version_tuple = tuplize_version(__version__)

View File

@@ -0,0 +1,11 @@
import orjson
import unittest
from NetUtils import encode, decode
class TestSerialize(unittest.TestCase):
def test_unbounded_int(self) -> None:
big_number = 2**200
round_tripped_big_number = decode(encode(orjson.Fragment(str(big_number).encode())))
self.assertEqual(big_number, round_tripped_big_number)
self.assertEqual(type(big_number), type(round_tripped_big_number))