mirror of
https://github.com/ArchipelagoMW/Archipelago.git
synced 2026-03-07 15:13:52 -08:00
Compare commits
2 Commits
main
...
core_orjso
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43100f2c43 | ||
|
|
c6df02a355 |
@@ -9,7 +9,6 @@ import sys
|
||||
import typing
|
||||
import time
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
import ModuleUpdate
|
||||
ModuleUpdate.update()
|
||||
@@ -22,10 +21,11 @@ if __name__ == "__main__":
|
||||
Utils.init_logging("TextClient", exception_logger="Client")
|
||||
|
||||
from MultiServer import CommandProcessor, mark_raw
|
||||
from NetUtils import (Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission, NetworkSlot,
|
||||
RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus, SlotType)
|
||||
from NetUtils import (Endpoint, decode, NetworkItem, JSONtoTextParser, ClientStatus, Permission, NetworkSlot,
|
||||
RawJSONtoTextParser, add_json_text, add_json_location, add_json_item, JSONTypes, HintStatus,
|
||||
SlotType, NetworkPlayer, encode_to_bytes)
|
||||
from Utils import Version, stream_input, async_start
|
||||
from worlds import network_data_package, AutoWorldRegister
|
||||
from worlds import network_data_package
|
||||
import os
|
||||
import ssl
|
||||
|
||||
@@ -502,10 +502,11 @@ class CommonContext:
|
||||
""" `msgs` JSON serializable """
|
||||
if not self.server or not self.server.socket.open or self.server.socket.closed:
|
||||
return
|
||||
await self.server.socket.send(encode(msgs))
|
||||
await self.server.socket.send(encode_to_bytes(msgs))
|
||||
|
||||
def consume_players_package(self, package: typing.List[tuple]):
|
||||
self.player_names = {slot: name for team, slot, name, orig_name in package if self.team == team}
|
||||
def consume_players_package(self, package: typing.List[NetworkPlayer]):
|
||||
self.player_names = {network_player.slot: network_player.name for network_player in package
|
||||
if self.team == network_player.team}
|
||||
self.player_names[0] = "Archipelago"
|
||||
|
||||
def event_invalid_slot(self):
|
||||
@@ -514,11 +515,12 @@ class CommonContext:
|
||||
def event_invalid_game(self):
|
||||
raise Exception('Invalid Game; please verify that you connected with the right game to the correct world.')
|
||||
|
||||
async def server_auth(self, password_requested: bool = False):
|
||||
async def server_auth(self, password_requested: bool = False) -> typing.Optional[str]:
|
||||
if password_requested and not self.password:
|
||||
logger.info('Enter the password required to join this game:')
|
||||
self.password = await self.console_input()
|
||||
return self.password
|
||||
return None
|
||||
|
||||
async def get_username(self):
|
||||
if not self.auth:
|
||||
@@ -942,11 +944,10 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
|
||||
logger.info('--------------------------------')
|
||||
logger.info('Room Information:')
|
||||
logger.info('--------------------------------')
|
||||
version = args["version"]
|
||||
ctx.server_version = Version(*version)
|
||||
ctx.server_version = Version.from_network_dict(args["version"])
|
||||
|
||||
if "generator_version" in args:
|
||||
ctx.generator_version = Version(*args["generator_version"])
|
||||
ctx.generator_version = Version.from_network_dict(args["generator_version"])
|
||||
logger.info(f'Server protocol version: {ctx.server_version.as_simple_string()}, '
|
||||
f'generator version: {ctx.generator_version.as_simple_string()}, '
|
||||
f'tags: {", ".join(args["tags"])}')
|
||||
@@ -1016,9 +1017,9 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
|
||||
ctx.slot = args["slot"]
|
||||
# int keys get lost in JSON transfer
|
||||
ctx.slot_info = {0: NetworkSlot("Archipelago", "Archipelago", SlotType.player)}
|
||||
ctx.slot_info.update({int(pid): data for pid, data in args["slot_info"].items()})
|
||||
ctx.slot_info.update({int(pid): NetworkSlot.from_network_dict(data) for pid, data in args["slot_info"].items()})
|
||||
ctx.hint_points = args.get("hint_points", 0)
|
||||
ctx.consume_players_package(args["players"])
|
||||
ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]])
|
||||
ctx.stored_data_notification_keys.add(f"_read_hints_{ctx.team}_{ctx.slot}")
|
||||
if ctx.game:
|
||||
game = ctx.game
|
||||
@@ -1067,17 +1068,17 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
|
||||
await ctx.send_msgs(sync_msg)
|
||||
if start_index == len(ctx.items_received):
|
||||
for item in args['items']:
|
||||
ctx.items_received.append(NetworkItem(*item))
|
||||
ctx.items_received.append(NetworkItem.from_network_dict(item))
|
||||
ctx.watcher_event.set()
|
||||
|
||||
elif cmd == 'LocationInfo':
|
||||
for item in [NetworkItem(*item) for item in args['locations']]:
|
||||
for item in [NetworkItem.from_network_dict(item) for item in args['locations']]:
|
||||
ctx.locations_info[item.location] = item
|
||||
ctx.watcher_event.set()
|
||||
|
||||
elif cmd == "RoomUpdate":
|
||||
if "players" in args:
|
||||
ctx.consume_players_package(args["players"])
|
||||
ctx.consume_players_package([NetworkPlayer.from_network_dict(player) for player in args["players"]])
|
||||
if "hint_points" in args:
|
||||
ctx.hint_points = args['hint_points']
|
||||
if "checked_locations" in args:
|
||||
|
||||
13
Main.py
13
Main.py
@@ -342,7 +342,18 @@ def main(args, seed=None, baked_server_options: dict[str, object] | None = None)
|
||||
# TODO: change to `"version": version_tuple` after getting better serialization
|
||||
AutoWorld.call_all(multiworld, "modify_multidata", multidata)
|
||||
|
||||
for key in ("slot_data", "er_hint_data"):
|
||||
base_types_keys = ["er_hint_data"]
|
||||
|
||||
# starting with 0.7.0 pre-encode slot data, until then multiserver does it on load
|
||||
if version_tuple < (0, 7, 0):
|
||||
base_types_keys.append("slot_data")
|
||||
else:
|
||||
for slot, data in multidata["slot_data"].items():
|
||||
multidata[slot] = NetUtils.encode_to_bytes(data)
|
||||
assert type(multidata[slot]) is bytes
|
||||
multidata["minimum_versions"]["server"] = max((0, 7, 0), multidata["minimum_versions"]["server"])
|
||||
|
||||
for key in base_types_keys:
|
||||
multidata[key] = convert_to_base_types(multidata[key])
|
||||
|
||||
multidata = zlib.compress(restricted_dumps(multidata), 9)
|
||||
|
||||
@@ -31,6 +31,7 @@ if typing.TYPE_CHECKING:
|
||||
from NetUtils import ServerConnection
|
||||
|
||||
import colorama
|
||||
import orjson
|
||||
import websockets
|
||||
from websockets.extensions.permessage_deflate import PerMessageDeflate
|
||||
try:
|
||||
@@ -41,9 +42,9 @@ except ImportError:
|
||||
|
||||
import NetUtils
|
||||
import Utils
|
||||
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text
|
||||
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text, __version__
|
||||
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
|
||||
SlotType, LocationStore, MultiData, Hint, HintStatus
|
||||
SlotType, LocationStore, MultiData, Hint, HintStatus, encode_to_bytes
|
||||
from BaseClasses import ItemClassification
|
||||
|
||||
|
||||
@@ -169,7 +170,7 @@ team_slot = typing.Tuple[int, int]
|
||||
|
||||
|
||||
class Context:
|
||||
dumper = staticmethod(encode)
|
||||
dumper = staticmethod(encode_to_bytes)
|
||||
loader = staticmethod(decode)
|
||||
|
||||
simple_options = {"hint_cost": int,
|
||||
@@ -453,7 +454,7 @@ class Context:
|
||||
self.read_data["race_mode"] = lambda: decoded_obj.get("race_mode", 0)
|
||||
mdata_ver = decoded_obj["minimum_versions"]["server"]
|
||||
if mdata_ver > version_tuple:
|
||||
raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver},"
|
||||
raise RuntimeError(f"Supplied Multidata (.archipelago) requires a server of at least version {mdata_ver}, "
|
||||
f"however this server is of version {version_tuple}")
|
||||
self.generator_version = Version(*decoded_obj["version"])
|
||||
clients_ver = decoded_obj["minimum_versions"].get("clients", {})
|
||||
@@ -490,6 +491,10 @@ class Context:
|
||||
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
|
||||
self.slot_data = decoded_obj['slot_data']
|
||||
for slot, data in self.slot_data.items():
|
||||
if not isinstance(data, bytes):
|
||||
data = encode_to_bytes(data)
|
||||
data = orjson.Fragment(data)
|
||||
self.slot_data[slot] = data
|
||||
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
|
||||
self.er_hint_data = {int(player): {int(address): name for address, name in loc_data.items()}
|
||||
for player, loc_data in decoded_obj["er_hint_data"].items()}
|
||||
@@ -1786,11 +1791,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
|
||||
|
||||
if cmd == 'Connect':
|
||||
if not args or 'password' not in args or type(args['password']) not in [str, type(None)] or \
|
||||
'game' not in args:
|
||||
'game' not in args or "version" not in args:
|
||||
await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments", 'text': 'Connect',
|
||||
"original_cmd": cmd}])
|
||||
return
|
||||
|
||||
args["version"] = Version.from_network_dict(args["version"])
|
||||
errors = set()
|
||||
if ctx.password and args['password'] != ctx.password:
|
||||
errors.add('InvalidPassword')
|
||||
@@ -1806,7 +1811,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
|
||||
if not ignore_game and args['game'] != game:
|
||||
errors.add('InvalidGame')
|
||||
minver = min_client_version if ignore_game else ctx.minimum_client_versions[slot]
|
||||
if minver > args['version']:
|
||||
if minver > args["version"]:
|
||||
errors.add('IncompatibleVersion')
|
||||
try:
|
||||
client.items_handling = args['items_handling']
|
||||
@@ -1814,7 +1819,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
|
||||
errors.add('InvalidItemsHandling')
|
||||
|
||||
# only exact version match allowed
|
||||
if ctx.compatibility == 0 and args['version'] != version_tuple:
|
||||
if ctx.compatibility == 0 and args['version'] != Version(__version__):
|
||||
errors.add('IncompatibleVersion')
|
||||
if errors:
|
||||
ctx.logger.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.")
|
||||
|
||||
57
NetUtils.py
57
NetUtils.py
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
import typing
|
||||
import enum
|
||||
import warnings
|
||||
from json import JSONEncoder, JSONDecoder
|
||||
import orjson
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from websockets import WebSocketServerProtocol as ServerConnection
|
||||
@@ -78,6 +78,11 @@ class NetworkPlayer(typing.NamedTuple):
|
||||
alias: str
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_network_dict(cls, source: dict):
|
||||
source.pop("class", None)
|
||||
return cls(**source)
|
||||
|
||||
|
||||
class NetworkSlot(typing.NamedTuple):
|
||||
"""Represents a particular slot across teams."""
|
||||
@@ -86,6 +91,11 @@ class NetworkSlot(typing.NamedTuple):
|
||||
type: SlotType
|
||||
group_members: Sequence[int] = () # only populated if type == group
|
||||
|
||||
@classmethod
|
||||
def from_network_dict(cls, source: dict):
|
||||
source.pop("class", None)
|
||||
return cls(**source)
|
||||
|
||||
|
||||
class NetworkItem(typing.NamedTuple):
|
||||
item: int
|
||||
@@ -94,6 +104,11 @@ class NetworkItem(typing.NamedTuple):
|
||||
""" Sending player, except in LocationInfo (from LocationScouts), where it is the receiving player. """
|
||||
flags: int = 0
|
||||
|
||||
@classmethod
|
||||
def from_network_dict(cls, source: dict):
|
||||
source.pop("class", None)
|
||||
return cls(**source)
|
||||
|
||||
|
||||
def _scan_for_TypedTuples(obj: typing.Any) -> typing.Any:
|
||||
if isinstance(obj, tuple) and hasattr(obj, "_fields"): # NamedTuple is not actually a parent class
|
||||
@@ -128,15 +143,12 @@ def convert_to_base_types(obj: typing.Any) -> _base_types:
|
||||
raise Exception(f"Cannot handle {type(obj)}")
|
||||
|
||||
|
||||
_encode = JSONEncoder(
|
||||
ensure_ascii=False,
|
||||
check_circular=False,
|
||||
separators=(',', ':'),
|
||||
).encode
|
||||
def encode_to_bytes(obj: typing.Any) -> bytes:
|
||||
return orjson.dumps(_scan_for_TypedTuples(obj), option=orjson.OPT_NON_STR_KEYS)
|
||||
|
||||
|
||||
def encode(obj: typing.Any) -> str:
|
||||
return _encode(_scan_for_TypedTuples(obj))
|
||||
return encode_to_bytes(obj).decode()
|
||||
|
||||
|
||||
def get_any_version(data: dict) -> Version:
|
||||
@@ -144,33 +156,10 @@ def get_any_version(data: dict) -> Version:
|
||||
return Version(int(data["major"]), int(data["minor"]), int(data["build"]))
|
||||
|
||||
|
||||
allowlist = {
|
||||
"NetworkPlayer": NetworkPlayer,
|
||||
"NetworkItem": NetworkItem,
|
||||
"NetworkSlot": NetworkSlot
|
||||
}
|
||||
|
||||
custom_hooks = {
|
||||
"Version": get_any_version
|
||||
}
|
||||
|
||||
|
||||
def _object_hook(o: typing.Any) -> typing.Any:
|
||||
if isinstance(o, dict):
|
||||
hook = custom_hooks.get(o.get("class", None), None)
|
||||
if hook:
|
||||
return hook(o)
|
||||
cls = allowlist.get(o.get("class", None), None)
|
||||
if cls:
|
||||
for key in tuple(o):
|
||||
if key not in cls._fields:
|
||||
del (o[key])
|
||||
return cls(**o)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
decode = JSONDecoder(object_hook=_object_hook).decode
|
||||
def decode(data: str | bytes) -> typing.Any:
|
||||
if isinstance(data, str):
|
||||
data = data.encode()
|
||||
return orjson.loads(data)
|
||||
|
||||
|
||||
class Endpoint:
|
||||
|
||||
5
Utils.py
5
Utils.py
@@ -46,6 +46,11 @@ class Version(typing.NamedTuple):
|
||||
def as_simple_string(self) -> str:
|
||||
return ".".join(str(item) for item in self)
|
||||
|
||||
@classmethod
|
||||
def from_network_dict(cls, source: dict):
|
||||
source.pop("class", None)
|
||||
return cls(**source)
|
||||
|
||||
|
||||
__version__ = "0.6.3"
|
||||
version_tuple = tuplize_version(__version__)
|
||||
|
||||
11
test/netutils/test_serialize.py
Normal file
11
test/netutils/test_serialize.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import orjson
|
||||
import unittest
|
||||
from NetUtils import encode, decode
|
||||
|
||||
|
||||
class TestSerialize(unittest.TestCase):
|
||||
def test_unbounded_int(self) -> None:
|
||||
big_number = 2**200
|
||||
round_tripped_big_number = decode(encode(orjson.Fragment(str(big_number).encode())))
|
||||
self.assertEqual(big_number, round_tripped_big_number)
|
||||
self.assertEqual(type(big_number), type(round_tripped_big_number))
|
||||
Reference in New Issue
Block a user