diff --git a/MultiServer.py b/MultiServer.py index ed14b6506f..126672a760 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -44,8 +44,9 @@ import NetUtils import Utils from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ - SlotType, LocationStore, MultiData, Hint, HintStatus + SlotType, LocationStore, MultiData, Hint, HintStatus, GamesPackage from BaseClasses import ItemClassification +from apmw.multiserver.gamespackagecache import GamesPackageCache min_client_version = Version(0, 5, 0) @@ -241,21 +242,38 @@ class Context: slot_info: typing.Dict[int, NetworkSlot] generator_version = Version(0, 0, 0) checksums: typing.Dict[str, str] + played_games: set[str] item_names: typing.Dict[str, typing.Dict[int, str]] - item_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] + item_name_groups: typing.Dict[str, typing.Dict[str, list[str]]] location_names: typing.Dict[str, typing.Dict[int, str]] - location_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] + location_name_groups: typing.Dict[str, typing.Dict[str, list[str]]] all_item_and_group_names: typing.Dict[str, typing.Set[str]] all_location_and_group_names: typing.Dict[str, typing.Set[str]] non_hintable_names: typing.Dict[str, typing.AbstractSet[str]] spheres: typing.List[typing.Dict[int, typing.Set[int]]] """ each sphere is { player: { location_id, ... } } """ + games_package_cache: GamesPackageCache logger: logging.Logger - def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, - hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled", - countdown_mode: str = "auto", remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0, - compatibility: int = 2, log_network: bool = False, logger: logging.Logger = logging.getLogger()): + def __init__( + self, + host: str, + port: int, + server_password: str, + password: str, + location_check_points: int, + hint_cost: int, + item_cheat: bool, + release_mode: str = "disabled", + collect_mode="disabled", + countdown_mode: str = "auto", + remaining_mode: str = "disabled", + auto_shutdown: typing.SupportsFloat = 0, + compatibility: int = 2, + log_network: bool = False, + games_package_cache: GamesPackageCache | None = None, + logger: logging.Logger = logging.getLogger(), + ) -> None: self.logger = logger super(Context, self).__init__() self.slot_info = {} @@ -306,6 +324,7 @@ class Context: self.save_dirty = False self.tags = ['AP'] self.games: typing.Dict[int, str] = {} + self.played_games = set() self.minimum_client_versions: typing.Dict[int, Version] = {} self.seed_name = "" self.groups = {} @@ -315,9 +334,10 @@ class Context: self.stored_data_notification_clients = collections.defaultdict(weakref.WeakSet) self.read_data = {} self.spheres = [] + self.games_package_cache = games_package_cache or GamesPackageCache() # init empty to satisfy linter, I suppose - self.gamespackage = {} + self.reduced_games_package = {} self.checksums = {} self.item_name_groups = {} self.location_name_groups = {} @@ -329,50 +349,11 @@ class Context: lambda: Utils.KeyedDefaultDict(lambda code: f'Unknown location (ID:{code})')) self.non_hintable_names = collections.defaultdict(frozenset) - self._load_game_data() - - # Data package retrieval - def _load_game_data(self): - import worlds - self.gamespackage = worlds.network_data_package["games"] - - self.item_name_groups = {world_name: world.item_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()} - self.location_name_groups = {world_name: world.location_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()} - for world_name, world in worlds.AutoWorldRegister.world_types.items(): - self.non_hintable_names[world_name] = world.hint_blacklist - - for game_package in self.gamespackage.values(): - # remove groups from data sent to clients - del game_package["item_name_groups"] - del game_package["location_name_groups"] - - def _init_game_data(self): - for game_name, game_package in self.gamespackage.items(): - if "checksum" in game_package: - self.checksums[game_name] = game_package["checksum"] - for item_name, item_id in game_package["item_name_to_id"].items(): - self.item_names[game_name][item_id] = item_name - for location_name, location_id in game_package["location_name_to_id"].items(): - self.location_names[game_name][location_id] = location_name - self.all_item_and_group_names[game_name] = \ - set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name]) - self.all_location_and_group_names[game_name] = \ - set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, [])) - - archipelago_item_names = self.item_names["Archipelago"] - archipelago_location_names = self.location_names["Archipelago"] - for game in [game_name for game_name in self.gamespackage if game_name != "Archipelago"]: - # Add Archipelago items and locations to each data package. - self.item_names[game].update(archipelago_item_names) - self.location_names[game].update(archipelago_location_names) - def item_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: - return self.gamespackage[game]["item_name_to_id"] if game in self.gamespackage else None + return self.reduced_games_package[game]["item_name_to_id"] if game in self.reduced_games_package else None def location_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: - return self.gamespackage[game]["location_name_to_id"] if game in self.gamespackage else None + return self.reduced_games_package[game]["location_name_to_id"] if game in self.reduced_games_package else None # General networking async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool: @@ -482,19 +463,17 @@ class Context: with open(multidatapath, 'rb') as f: data = f.read() - self._load(self.decompress(data), {}, use_embedded_server_options) + self._load(self.decompress(data), use_embedded_server_options) self.data_filename = multidatapath @staticmethod - def decompress(data: bytes) -> dict: + def decompress(data: bytes) -> typing.Any: format_version = data[0] if format_version > 3: raise Utils.VersionException("Incompatible multidata.") return restricted_loads(zlib.decompress(data[1:])) - def _load(self, decoded_obj: MultiData, game_data_packages: typing.Dict[str, typing.Any], - use_embedded_server_options: bool): - + def _load(self, decoded_obj: MultiData, use_embedded_server_options: bool) -> None: self.read_data = {} # there might be a better place to put this. race_mode = decoded_obj.get("race_mode", 0) @@ -515,6 +494,7 @@ class Context: self.slot_info = decoded_obj["slot_info"] self.games = {slot: slot_info.game for slot, slot_info in self.slot_info.items()} + self.played_games = {"Archipelago"} | {self.games[x] for x in range(1, len(self.games) + 1)} self.groups = {slot: set(slot_info.group_members) for slot, slot_info in self.slot_info.items() if slot_info.type == SlotType.group} @@ -559,18 +539,11 @@ class Context: server_options = decoded_obj.get("server_options", {}) self._set_options(server_options) - # embedded data package - for game_name, data in decoded_obj.get("datapackage", {}).items(): - if game_name in game_data_packages: - data = game_data_packages[game_name] - self.logger.info(f"Loading embedded data package for game {game_name}") - self.gamespackage[game_name] = data - self.item_name_groups[game_name] = data["item_name_groups"] - if "location_name_groups" in data: - self.location_name_groups[game_name] = data["location_name_groups"] - del data["location_name_groups"] - del data["item_name_groups"] # remove from data package, but keep in self.item_name_groups + # load and apply world data and (embedded) data package + self._load_world_data() + self._load_data_package(decoded_obj.get("datapackage", {})) self._init_game_data() + for game_name, data in self.item_name_groups.items(): self.read_data[f"item_name_groups_{game_name}"] = lambda lgame=game_name: self.item_name_groups[lgame] for game_name, data in self.location_name_groups.items(): @@ -579,6 +552,55 @@ class Context: # sorted access spheres self.spheres = decoded_obj.get("spheres", []) + def _load_world_data(self) -> None: + import worlds + + for world_name, world in worlds.AutoWorldRegister.world_types.items(): + # TODO: move hint_blacklist into GamesPackage? + self.non_hintable_names[world_name] = world.hint_blacklist + + def _load_data_package(self, data_package: dict[str, GamesPackage]) -> None: + """Populates reduced_games_package, item_name_groups, location_name_groups from static data and data_package""" + # NOTE: for worlds loaded from db, only checksum is set in GamesPackage, but this is handled by cache + for game_name in sorted(self.played_games): + if game_name in data_package: + self.logger.info(f"Loading embedded data package for game {game_name}") + data = self.games_package_cache.get(game_name, data_package[game_name]) + else: + # NOTE: we still allow uploading a game without datapackage. Once that is changed, we could drop this. + data = self.games_package_cache.get_static(game_name) + ( + self.reduced_games_package[game_name], + self.item_name_groups[game_name], + self.location_name_groups[game_name], + ) = data + + del self.games_package_cache # Not used past this point. Free memory. + + def _init_game_data(self) -> None: + """Update internal values from previously loaded data packages""" + for game_name, game_package in self.reduced_games_package.items(): + if game_name not in self.played_games: + continue + if "checksum" in game_package: + self.checksums[game_name] = game_package["checksum"] + # NOTE: we could save more memory by moving the stuff below to data package cache as well + for item_name, item_id in game_package["item_name_to_id"].items(): + self.item_names[game_name][item_id] = item_name + for location_name, location_id in game_package["location_name_to_id"].items(): + self.location_names[game_name][location_id] = location_name + self.all_item_and_group_names[game_name] = \ + set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name]) + self.all_location_and_group_names[game_name] = \ + set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, [])) + + archipelago_item_names = self.item_names["Archipelago"] + archipelago_location_names = self.location_names["Archipelago"] + for game in [game_name for game_name in self.reduced_games_package if game_name != "Archipelago"]: + # Add Archipelago items and locations to each data package. + self.item_names[game].update(archipelago_item_names) + self.location_names[game].update(archipelago_location_names) + # saving def save(self, now=False) -> bool: @@ -919,12 +941,10 @@ async def server(websocket: "ServerConnection", path: str = "/", ctx: Context = async def on_client_connected(ctx: Context, client: Client): - games = {ctx.games[x] for x in range(1, len(ctx.games) + 1)} - games.add("Archipelago") await ctx.send_msgs(client, [{ 'cmd': 'RoomInfo', 'password': bool(ctx.password), - 'games': games, + 'games': sorted(ctx.played_games), # tags are for additional features in the communication. # Name them by feature or fork, as you feel is appropriate. 'tags': ctx.tags, @@ -933,8 +953,7 @@ async def on_client_connected(ctx: Context, client: Client): 'permissions': get_permissions(ctx), 'hint_cost': ctx.hint_cost, 'location_check_points': ctx.location_check_points, - 'datapackage_checksums': {game: game_data["checksum"] for game, game_data - in ctx.gamespackage.items() if game in games and "checksum" in game_data}, + 'datapackage_checksums': ctx.checksums, 'seed_name': ctx.seed_name, 'time': time.time(), }]) @@ -1940,25 +1959,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): await ctx.send_msgs(client, reply) elif cmd == "GetDataPackage": - exclusions = args.get("exclusions", []) - if "games" in args: - games = {name: game_data for name, game_data in ctx.gamespackage.items() - if name in set(args.get("games", []))} - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": {"games": games}}]) - # TODO: remove exclusions behaviour around 0.5.0 - elif exclusions: - exclusions = set(exclusions) - games = {name: game_data for name, game_data in ctx.gamespackage.items() - if name not in exclusions} - - package = {"games": games} - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": package}]) - - else: - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": {"games": ctx.gamespackage}}]) + games = { + name: game_data for name, game_data in ctx.reduced_games_package.items() + if name in set(args.get("games", [])) + } + await ctx.send_msgs(client, [{"cmd": "DataPackage", "data": {"games": games}}]) elif client.auth: if cmd == "ConnectUpdate": diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 4257c6aff3..2cade4960d 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -13,6 +13,7 @@ import threading import time import typing import sys +from asyncio import AbstractEventLoop import websockets from pony.orm import commit, db_session, select @@ -24,8 +25,10 @@ from MultiServer import ( server_per_message_deflate_factory, ) from Utils import restricted_loads, cache_argsless +from NetUtils import GamesPackage +from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache from .locker import Locker -from .models import Command, GameDataPackage, Room, db +from .models import Command, Room, db class CustomClientMessageProcessor(ClientMessageProcessor): @@ -62,18 +65,39 @@ class DBCommandProcessor(ServerCommandProcessor): class WebHostContext(Context): room_id: int + video: dict[tuple[int, int], tuple[str, str]] + main_loop: AbstractEventLoop + static_server_data: StaticServerData - def __init__(self, static_server_data: dict, logger: logging.Logger): + def __init__( + self, + static_server_data: StaticServerData, + games_package_cache: DBGamesPackageCache, + logger: logging.Logger, + ) -> None: # static server data is used during _load_game_data to load required data, # without needing to import worlds system, which takes quite a bit of memory - self.static_server_data = static_server_data - super(WebHostContext, self).__init__("", 0, "", "", 1, - 40, True, "enabled", "enabled", - "enabled", 0, 2, logger=logger) - del self.static_server_data - self.main_loop = asyncio.get_running_loop() - self.video = {} + super(WebHostContext, self).__init__( + "", + 0, + "", + "", + 1, + 40, + True, + "enabled", + "enabled", + "enabled", + 0, + 2, + games_package_cache=games_package_cache, + logger=logger, + ) self.tags = ["AP", "WebHost"] + self.video = {} + self.main_loop = asyncio.get_running_loop() + self.static_server_data = static_server_data + self.games_package_cache = games_package_cache def __del__(self): try: @@ -83,12 +107,6 @@ class WebHostContext(Context): except ImportError: self.logger.debug("Context destroyed") - def _load_game_data(self): - for key, value in self.static_server_data.items(): - # NOTE: attributes are mutable and shared, so they will have to be copied before being modified - setattr(self, key, value) - self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names) - async def listen_to_db_commands(self): cmdprocessor = DBCommandProcessor(self) @@ -118,42 +136,14 @@ class WebHostContext(Context): self.port = get_random_port() multidata = self.decompress(room.seed.multidata) - game_data_packages = {} + return self._load(multidata, True) - static_gamespackage = self.gamespackage # this is shared across all rooms - static_item_name_groups = self.item_name_groups - static_location_name_groups = self.location_name_groups - self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})} # this may be modified by _load - self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})} - self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})} - missing_checksum = False - - for game in list(multidata.get("datapackage", {})): - game_data = multidata["datapackage"][game] - if "checksum" in game_data: - if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]: - # non-custom. remove from multidata and use static data - # games package could be dropped from static data once all rooms embed data package - del multidata["datapackage"][game] - else: - row = GameDataPackage.get(checksum=game_data["checksum"]) - if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete - game_data_packages[game] = restricted_loads(row.data) - continue - else: - self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}") - else: - missing_checksum = True # Game rolled on old AP and will load data package from multidata - self.gamespackage[game] = static_gamespackage.get(game, {}) - self.item_name_groups[game] = static_item_name_groups.get(game, {}) - self.location_name_groups[game] = static_location_name_groups.get(game, {}) - - if not game_data_packages and not missing_checksum: - # all static -> use the static dicts directly - self.gamespackage = static_gamespackage - self.item_name_groups = static_item_name_groups - self.location_name_groups = static_location_name_groups - return self._load(multidata, game_data_packages, True) + def _load_world_data(self): + # Use static_server_data, but skip static data package since that is in cache anyway. + # Also NOT importing worlds here! + # FIXME: does this copy the non_hintable_names (also for games not part of the room)? + self.non_hintable_names = collections.defaultdict(frozenset, self.static_server_data["non_hintable_names"]) + del self.static_server_data # Not used past this point. Free memory. def init_save(self, enabled: bool = True): self.saving = enabled @@ -185,34 +175,23 @@ def get_random_port(): return random.randint(49152, 65535) +class StaticServerData(typing.TypedDict, total=True): + non_hintable_names: dict[str, typing.AbstractSet[str]] + games_package: dict[str, GamesPackage] + + @cache_argsless -def get_static_server_data() -> dict: +def get_static_server_data() -> StaticServerData: import worlds - data = { + + return { "non_hintable_names": { world_name: world.hint_blacklist for world_name, world in worlds.AutoWorldRegister.world_types.items() }, - "gamespackage": { - world_name: { - key: value - for key, value in game_package.items() - if key not in ("item_name_groups", "location_name_groups") - } - for world_name, game_package in worlds.network_data_package["games"].items() - }, - "item_name_groups": { - world_name: world.item_name_groups - for world_name, world in worlds.AutoWorldRegister.world_types.items() - }, - "location_name_groups": { - world_name: world.location_name_groups - for world_name, world in worlds.AutoWorldRegister.world_types.items() - }, + "games_package": worlds.network_data_package["games"] } - return data - def set_up_logging(room_id) -> logging.Logger: import os @@ -245,9 +224,18 @@ def tear_down_logging(room_id): del logging.Logger.manager.loggerDict[logger_name] -def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, - cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], - host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue): +def run_server_process( + name: str, + ponyconfig: dict[str, typing.Any], + static_server_data: StaticServerData, + cert_file: typing.Optional[str], + cert_key_file: typing.Optional[str], + host: str, + rooms_to_run: multiprocessing.Queue, + rooms_shutting_down: multiprocessing.Queue, +) -> None: + import gc + from setproctitle import setproctitle setproctitle(name) @@ -263,6 +251,9 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit)) del resource, file_limit + # prime the data package cache with static data + games_package_cache = DBGamesPackageCache(static_server_data["games_package"]) + # establish DB connection for multidata and multisave db.bind(**ponyconfig) db.generate_mapping(check_tables=False) @@ -270,8 +261,6 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, if "worlds" in sys.modules: raise Exception("Worlds system should not be loaded in the custom server.") - import gc - if not cert_file: def get_ssl_context(): return None @@ -296,7 +285,7 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, with Locker(f"RoomLocker {room_id}"): try: logger = set_up_logging(room_id) - ctx = WebHostContext(static_server_data, logger) + ctx = WebHostContext(static_server_data, games_package_cache, logger) ctx.load(room_id) ctx.init_save() assert ctx.server is None diff --git a/apmw/__init__.py b/apmw/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/multiserver/__init__.py b/apmw/multiserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/multiserver/gamespackagecache.py b/apmw/multiserver/gamespackagecache.py new file mode 100644 index 0000000000..c90200c744 --- /dev/null +++ b/apmw/multiserver/gamespackagecache.py @@ -0,0 +1,96 @@ +import typing as t +from weakref import WeakValueDictionary + +from NetUtils import GamesPackage + +GameAndChecksum = tuple[str, str | None] +ItemNameGroups = dict[str, list[str]] +LocationNameGroups = dict[str, list[str]] + + +K = t.TypeVar("K") +V = t.TypeVar("V") + + +class DictLike(dict[K, V]): + __slots__ = ("__weakref__",) + + +class GamesPackageCache: + # NOTE: this uses 3 separate collections because unpacking the get() result would end the container lifetime + _reduced_games_packages: WeakValueDictionary[GameAndChecksum, GamesPackage] + """Does not include item_name_groups nor location_name_groups""" + _item_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]] + _location_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]] + + def __init__(self) -> None: + self._reduced_games_packages = WeakValueDictionary() + self._item_name_groups = WeakValueDictionary() + self._location_name_groups = WeakValueDictionary() + + def _get( + self, + cache_key: GameAndChecksum, + ) -> tuple[GamesPackage | None, ItemNameGroups | None, LocationNameGroups | None]: + if cache_key[1] is None: + return None, None, None + return ( + self._reduced_games_packages.get(cache_key, None), + self._item_name_groups.get(cache_key, None), + self._location_name_groups.get(cache_key, None), + ) + + def get( + self, + game: str, + full_games_package: GamesPackage, + ) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + """Loads and caches embedded data package provided by multidata""" + cache_key = (game, full_games_package.get("checksum", None)) + cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups = self._get(cache_key) + + if cached_reduced_games_package is None: + cached_reduced_games_package = t.cast( + t.Any, + DictLike( + { + "item_name_to_id": full_games_package["item_name_to_id"], + "location_name_to_id": full_games_package["location_name_to_id"], + "checksum": full_games_package.get("checksum", None), + } + ), + ) + if cache_key[1] is not None: # only cache if checksum is available + self._reduced_games_packages[cache_key] = cached_reduced_games_package + + if cached_item_name_groups is None: + # optimize strings to be references instead of copies + item_names = {name: name for name in cached_reduced_games_package["item_name_to_id"].keys()} + cached_item_name_groups = DictLike( + { + group_name: [item_names.get(item_name, item_name) for item_name in group_items] + for group_name, group_items in full_games_package["item_name_groups"].items() + } + ) + if cache_key[1] is not None: # only cache if checksum is available + self._item_name_groups[cache_key] = cached_item_name_groups + + if cached_location_name_groups is None: + # optimize strings to be references instead of copies + location_names = {name: name for name in cached_reduced_games_package["location_name_to_id"].keys()} + cached_location_name_groups = DictLike( + { + group_name: [location_names.get(location_name, location_name) for location_name in group_locations] + for group_name, group_locations in full_games_package.get("location_name_groups", {}).items() + } + ) + if cache_key[1] is not None: # only cache if checksum is available + self._location_name_groups[cache_key] = cached_location_name_groups + + return cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups + + def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + """Loads legacy data package from installed worlds""" + import worlds + + return self.get(game, worlds.network_data_package["games"][game]) diff --git a/apmw/webhost/__init__.py b/apmw/webhost/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/webhost/customserver/__init__.py b/apmw/webhost/customserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/webhost/customserver/gamespackagecache.py b/apmw/webhost/customserver/gamespackagecache.py new file mode 100644 index 0000000000..c58af4b4c6 --- /dev/null +++ b/apmw/webhost/customserver/gamespackagecache.py @@ -0,0 +1,42 @@ +from typing_extensions import override + +from NetUtils import GamesPackage +from Utils import restricted_loads +from apmw.multiserver.gamespackagecache import GamesPackageCache, ItemNameGroups, LocationNameGroups + + +class DBGamesPackageCache(GamesPackageCache): + _static: dict[str, tuple[GamesPackage, ItemNameGroups, LocationNameGroups]] + + def __init__(self, static_games_package: dict[str, GamesPackage]) -> None: + super().__init__() + self._static = { + game: GamesPackageCache.get(self, game, games_package) + for game, games_package in static_games_package.items() + } + + @override + def get( + self, + game: str, + full_games_package: GamesPackage, + ) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + # for games started on webhost, full_games_package is likely unpopulated and only has the checksum field + cache_key = (game, full_games_package.get("checksum", None)) + cached = self._get(cache_key) + if any(value is None for value in cached): + if "checksum" not in full_games_package: + return super().get(game, full_games_package) # no checksum, assume fully populated + + from WebHostLib.models import GameDataPackage + + row: GameDataPackage | None = GameDataPackage.get(checksum=full_games_package["checksum"]) + if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8 ... + return super().get(game, restricted_loads(row.data)) + return super().get(game, full_games_package) # ... in which case full_games_package should be populated + + return cached # type: ignore # mypy doesn't understand any value is None + + @override + def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + return self._static[game] diff --git a/test/multiserver/__init__.py b/test/multiserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/multiserver/test_gamespackage_cache.py b/test/multiserver/test_gamespackage_cache.py new file mode 100644 index 0000000000..440a46a08d --- /dev/null +++ b/test/multiserver/test_gamespackage_cache.py @@ -0,0 +1,132 @@ +import typing as t +from copy import deepcopy +from unittest import TestCase + +from typing_extensions import override + +import NetUtils +from NetUtils import GamesPackage +from apmw.multiserver.gamespackagecache import GamesPackageCache + + +class GamesPackageCacheTest(TestCase): + cache: GamesPackageCache + any_game: t.ClassVar[str] = "APQuest" + example_games_package: GamesPackage = { + "item_name_to_id": {"Item 1": 1}, + "item_name_groups": {"Everything": ["Item 1"]}, + "location_name_to_id": {"Location 1": 1}, + "location_name_groups": {"Everywhere": ["Location 1"]}, + "checksum": "1234", + } + + @override + def setUp(self) -> None: + self.cache = GamesPackageCache() + + def test_get_static_is_same(self) -> None: + """Tests that get_static returns the same objects twice""" + reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get_static(self.any_game) + reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game) + self.assertIs(reduced_games_package1, reduced_games_package2) + self.assertIs(item_name_groups1, item_name_groups2) + self.assertIs(location_name_groups1, location_name_groups2) + + def test_get_static_data_format(self) -> None: + """Tests that get_static returns data in the correct format""" + reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game) + self.assertTrue(reduced_games_package["checksum"]) + self.assertTrue(reduced_games_package["item_name_to_id"]) + self.assertTrue(reduced_games_package["location_name_to_id"]) + self.assertNotIn("item_name_groups", reduced_games_package) + self.assertNotIn("location_name_groups", reduced_games_package) + self.assertTrue(item_name_groups["Everything"]) + self.assertTrue(location_name_groups["Everywhere"]) + + def test_get_static_is_serializable(self) -> None: + """Tests that get_static returns data that can be serialized""" + NetUtils.encode(self.cache.get_static(self.any_game)) + + def test_get_static_missing_raises(self) -> None: + """Tests that get_static raises KeyError if the world is missing""" + with self.assertRaises(KeyError): + _ = self.cache.get_static("Does not exist") + + def test_eviction(self) -> None: + """Tests that unused items get evicted from cache""" + game_name = "Test" + before_add = len(self.cache._reduced_games_packages) + data = self.cache.get(game_name, self.example_games_package) + self.assertTrue(data) + self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages)) + + del data + if len(self.cache._reduced_games_packages) != before_add: # gc.collect() may not even be required + import gc + + gc.collect() + + self.assertEqual(before_add, len(self.cache._reduced_games_packages)) + + def test_get_required_field(self) -> None: + """Tests that missing required field raises a KeyError""" + for field in ("item_name_to_id", "location_name_to_id", "item_name_groups"): + with self.subTest(field=field): + games_package = deepcopy(self.example_games_package) + del games_package[field] # type: ignore + with self.assertRaises(KeyError): + _ = self.cache.get(self.any_game, games_package) + + def test_get_optional_properties(self) -> None: + """Tests that missing optional field works""" + for field in ("checksum", "location_name_groups"): + with self.subTest(field=field): + games_package = deepcopy(self.example_games_package) + del games_package[field] # type: ignore + _, item_name_groups, location_name_groups = self.cache.get(self.any_game, games_package) + self.assertTrue(item_name_groups) + self.assertEqual(field != "location_name_groups", bool(location_name_groups)) + + def test_item_name_deduplication(self) -> None: + n = 1 + s1 = f"Item {n}" + s2 = f"Item {n}" + # check if the deduplication is actually gonna do anything + self.assertIsNot(s1, s2) + self.assertEqual(s1, s2) + # do the thing + game_name = "Test" + games_package: GamesPackage = { + "item_name_to_id": {s1: n}, + "item_name_groups": {"Everything": [s2]}, + "location_name_to_id": {}, + "location_name_groups": {}, + "checksum": "1234", + } + reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package) + self.assertIs( + next(iter(reduced_games_package["item_name_to_id"].keys())), + item_name_groups["Everything"][0], + ) + + def test_location_name_deduplication(self) -> None: + n = 1 + s1 = f"Location {n}" + s2 = f"Location {n}" + # check if the deduplication is actually gonna do anything + self.assertIsNot(s1, s2) + self.assertEqual(s1, s2) + # do the thing + game_name = "Test" + games_package: GamesPackage = { + "item_name_to_id": {}, + "item_name_groups": {}, + "location_name_to_id": {s1: n}, + "location_name_groups": {"Everywhere": [s2]}, + "checksum": "1234", + } + reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package) + self.assertIs( + next(iter(reduced_games_package["location_name_to_id"].keys())), + location_name_groups["Everywhere"][0], + ) diff --git a/test/webhost_customserver/__init__.py b/test/webhost_customserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/webhost_customserver/test_gamespackage_cache.py b/test/webhost_customserver/test_gamespackage_cache.py new file mode 100644 index 0000000000..58ea2e37f9 --- /dev/null +++ b/test/webhost_customserver/test_gamespackage_cache.py @@ -0,0 +1,147 @@ +import typing as t +from copy import deepcopy + +from typing_extensions import override + +from test.multiserver.test_gamespackage_cache import GamesPackageCacheTest + +import Utils +from NetUtils import GamesPackage +from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache + + +class FakeGameDataPackage: + _rows: "t.ClassVar[dict[str, FakeGameDataPackage]]" = {} + data: bytes + + @classmethod + def get(cls, checksum: str) -> "FakeGameDataPackage | None": + return cls._rows.get(checksum, None) + + @classmethod + def add(cls, checksum: str, full_games_package: GamesPackage) -> None: + row = FakeGameDataPackage() + row.data = Utils.restricted_dumps(full_games_package) + cls._rows[checksum] = row + + +class DBGamesPackageCacheTest(GamesPackageCacheTest): + cache: DBGamesPackageCache + any_game: t.ClassVar[str] = "My Game" + static_data: t.ClassVar[dict[str, GamesPackage]] = { # noqa: pycharm doesn't understand this + "My Game": { + "item_name_to_id": {"Item 1": 1}, + "location_name_to_id": {"Location 1": 1}, + "item_name_groups": {"Everything": ["Item 1"]}, + "location_name_groups": {"Everywhere": ["Location 1"]}, + "checksum": "2345", + } + } + orig_db_type: t.ClassVar[type] + + @override + @classmethod + def setUpClass(cls) -> None: + import WebHostLib.models + + cls.orig_db_type = WebHostLib.models.GameDataPackage + WebHostLib.models.GameDataPackage = FakeGameDataPackage # type: ignore + + @override + def setUp(self) -> None: + self.cache = DBGamesPackageCache(self.static_data) + + @override + @classmethod + def tearDownClass(cls) -> None: + import WebHostLib.models + + WebHostLib.models.GameDataPackage = cls.orig_db_type # type: ignore + + def assert_conversion( + self, + full_games_package: GamesPackage, + reduced_games_package: dict[str, t.Any], + item_name_groups: dict[str, t.Any], + location_name_groups: dict[str, t.Any], + ) -> None: + for key in ("item_name_to_id", "location_name_to_id", "checksum"): + if key in full_games_package: + self.assertEqual(reduced_games_package[key], full_games_package[key]) # noqa: pycharm + self.assertEqual(item_name_groups, full_games_package["item_name_groups"]) + self.assertEqual(location_name_groups, full_games_package["location_name_groups"]) + + def assert_static_conversion( + self, + full_games_package: GamesPackage, + reduced_games_package: dict[str, t.Any], + item_name_groups: dict[str, t.Any], + location_name_groups: dict[str, t.Any], + ) -> None: + self.assert_conversion(full_games_package, reduced_games_package, item_name_groups, location_name_groups) + for key in ("item_name_to_id", "location_name_to_id", "checksum"): + self.assertIs(reduced_games_package[key], full_games_package[key]) # noqa: pycharm + + def test_get_static_contents(self) -> None: + """Tests that get_static returns the correct data""" + reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game) + for key in ("item_name_to_id", "location_name_to_id", "checksum"): + self.assertIs(reduced_games_package[key], self.static_data[self.any_game][key]) # noqa: pycharm + self.assertEqual(item_name_groups, self.static_data[self.any_game]["item_name_groups"]) + self.assertEqual(location_name_groups, self.static_data[self.any_game]["location_name_groups"]) + + def test_static_not_evicted(self) -> None: + """Tests that static data is not evicted from cache during gc""" + import gc + + game_name = next(iter(self.static_data.keys())) + ids = [id(o) for o in self.cache.get_static(game_name)] + gc.collect() + self.assertEqual(ids, [id(o) for o in self.cache.get_static(game_name)]) + + def test_get_is_static(self) -> None: + """Tests that a get with correct checksum return the static items""" + # NOTE: this is only true for the DB cache, not the "regular" one, since we want to avoid loading worlds there + cks: GamesPackage = {"checksum": self.static_data[self.any_game]["checksum"]} # noqa: pycharm doesn't like this + reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get(self.any_game, cks) + reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game) + self.assertIs(reduced_games_package1, reduced_games_package2) + self.assertEqual(location_name_groups1, location_name_groups2) + self.assertEqual(item_name_groups1, item_name_groups2) + + def test_get_from_db(self) -> None: + """Tests that a get with only checksum will load the full data from db and is cached""" + game_name = "Another Game" + full_games_package = deepcopy(self.static_data[self.any_game]) + full_games_package["checksum"] = "3456" + cks: GamesPackage = {"checksum": full_games_package["checksum"]} # noqa: pycharm doesn't like this + FakeGameDataPackage.add(full_games_package["checksum"], full_games_package) + before_add = len(self.cache._reduced_games_packages) + data = self.cache.get(game_name, cks) + self.assert_conversion(full_games_package, *data) # type: ignore + self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages)) + + def test_get_missing_from_db_uses_full_games_package(self) -> None: + """Tests that a get with full data (missing from db) will use the full data and is cached""" + game_name = "Yet Another Game" + full_games_package = deepcopy(self.static_data[self.any_game]) + full_games_package["checksum"] = "4567" + before_add = len(self.cache._reduced_games_packages) + data = self.cache.get(game_name, full_games_package) + self.assert_conversion(full_games_package, *data) # type: ignore + self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages)) + + def test_get_without_checksum_uses_full_games_package(self) -> None: + """Tests that a get with full data and no checksum will use the full data and is not cached""" + game_name = "Yet Another Game" + full_games_package = deepcopy(self.static_data[self.any_game]) + del full_games_package["checksum"] + before_add = len(self.cache._reduced_games_packages) + data = self.cache.get(game_name, full_games_package) + self.assert_conversion(full_games_package, *data) # type: ignore + self.assertEqual(before_add, len(self.cache._reduced_games_packages)) + + def test_get_missing_from_db_raises(self) -> None: + """Tests that a get that requires a row to exist raise an exception if it doesn't""" + with self.assertRaises(Exception): + _ = self.cache.get("Does not exist", {"checksum": "0000"})