diff --git a/MultiServer.py b/MultiServer.py index 52c80c5540..ac3c9c18c1 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -43,8 +43,9 @@ import NetUtils import Utils from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ - SlotType, LocationStore, MultiData, Hint, HintStatus + SlotType, LocationStore, MultiData, Hint, HintStatus, GamesPackage from BaseClasses import ItemClassification +from apmw.multiserver.gamespackage.cache import GamesPackageCache min_client_version = Version(0, 5, 0) @@ -240,21 +241,38 @@ class Context: slot_info: typing.Dict[int, NetworkSlot] generator_version = Version(0, 0, 0) checksums: typing.Dict[str, str] + played_games: set[str] item_names: typing.Dict[str, typing.Dict[int, str]] - item_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] + item_name_groups: typing.Dict[str, typing.Dict[str, list[str]]] location_names: typing.Dict[str, typing.Dict[int, str]] - location_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] + location_name_groups: typing.Dict[str, typing.Dict[str, list[str]]] all_item_and_group_names: typing.Dict[str, typing.Set[str]] all_location_and_group_names: typing.Dict[str, typing.Set[str]] non_hintable_names: typing.Dict[str, typing.AbstractSet[str]] spheres: typing.List[typing.Dict[int, typing.Set[int]]] """ each sphere is { player: { location_id, ... } } """ + games_package_cache: GamesPackageCache logger: logging.Logger - def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, - hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled", - countdown_mode: str = "auto", remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0, - compatibility: int = 2, log_network: bool = False, logger: logging.Logger = logging.getLogger()): + def __init__( + self, + host: str, + port: int, + server_password: str, + password: str, + location_check_points: int, + hint_cost: int, + item_cheat: bool, + release_mode: str = "disabled", + collect_mode="disabled", + countdown_mode: str = "auto", + remaining_mode: str = "disabled", + auto_shutdown: typing.SupportsFloat = 0, + compatibility: int = 2, + log_network: bool = False, + games_package_cache: GamesPackageCache | None = None, + logger: logging.Logger = logging.getLogger() + ) -> None: self.logger = logger super(Context, self).__init__() self.slot_info = {} @@ -305,6 +323,7 @@ class Context: self.save_dirty = False self.tags = ['AP'] self.games: typing.Dict[int, str] = {} + self.played_games = set() self.minimum_client_versions: typing.Dict[int, Version] = {} self.seed_name = "" self.groups = {} @@ -314,9 +333,10 @@ class Context: self.stored_data_notification_clients = collections.defaultdict(weakref.WeakSet) self.read_data = {} self.spheres = [] + self.games_package_cache = games_package_cache or GamesPackageCache() # init empty to satisfy linter, I suppose - self.gamespackage = {} + self.reduced_games_package = {} self.checksums = {} self.item_name_groups = {} self.location_name_groups = {} @@ -328,50 +348,11 @@ class Context: lambda: Utils.KeyedDefaultDict(lambda code: f'Unknown location (ID:{code})')) self.non_hintable_names = collections.defaultdict(frozenset) - self._load_game_data() - - # Data package retrieval - def _load_game_data(self): - import worlds - self.gamespackage = worlds.network_data_package["games"] - - self.item_name_groups = {world_name: world.item_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()} - self.location_name_groups = {world_name: world.location_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()} - for world_name, world in worlds.AutoWorldRegister.world_types.items(): - self.non_hintable_names[world_name] = world.hint_blacklist - - for game_package in self.gamespackage.values(): - # remove groups from data sent to clients - del game_package["item_name_groups"] - del game_package["location_name_groups"] - - def _init_game_data(self): - for game_name, game_package in self.gamespackage.items(): - if "checksum" in game_package: - self.checksums[game_name] = game_package["checksum"] - for item_name, item_id in game_package["item_name_to_id"].items(): - self.item_names[game_name][item_id] = item_name - for location_name, location_id in game_package["location_name_to_id"].items(): - self.location_names[game_name][location_id] = location_name - self.all_item_and_group_names[game_name] = \ - set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name]) - self.all_location_and_group_names[game_name] = \ - set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, [])) - - archipelago_item_names = self.item_names["Archipelago"] - archipelago_location_names = self.location_names["Archipelago"] - for game in [game_name for game_name in self.gamespackage if game_name != "Archipelago"]: - # Add Archipelago items and locations to each data package. - self.item_names[game].update(archipelago_item_names) - self.location_names[game].update(archipelago_location_names) - def item_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: - return self.gamespackage[game]["item_name_to_id"] if game in self.gamespackage else None + return self.reduced_games_package[game]["item_name_to_id"] if game in self.reduced_games_package else None def location_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: - return self.gamespackage[game]["location_name_to_id"] if game in self.gamespackage else None + return self.reduced_games_package[game]["location_name_to_id"] if game in self.reduced_games_package else None # General networking async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool: @@ -481,19 +462,17 @@ class Context: with open(multidatapath, 'rb') as f: data = f.read() - self._load(self.decompress(data), {}, use_embedded_server_options) + self._load(self.decompress(data), use_embedded_server_options) self.data_filename = multidatapath @staticmethod - def decompress(data: bytes) -> dict: + def decompress(data: bytes) -> typing.Any: format_version = data[0] if format_version > 3: raise Utils.VersionException("Incompatible multidata.") return restricted_loads(zlib.decompress(data[1:])) - def _load(self, decoded_obj: MultiData, game_data_packages: typing.Dict[str, typing.Any], - use_embedded_server_options: bool): - + def _load(self, decoded_obj: MultiData, use_embedded_server_options: bool) -> None: self.read_data = {} # there might be a better place to put this. self.read_data["race_mode"] = lambda: decoded_obj.get("race_mode", 0) @@ -513,6 +492,7 @@ class Context: self.slot_info = decoded_obj["slot_info"] self.games = {slot: slot_info.game for slot, slot_info in self.slot_info.items()} + self.played_games = {"Archipelago"} | {self.games[x] for x in range(1, len(self.games) + 1)} self.groups = {slot: set(slot_info.group_members) for slot, slot_info in self.slot_info.items() if slot_info.type == SlotType.group} @@ -557,18 +537,11 @@ class Context: server_options = decoded_obj.get("server_options", {}) self._set_options(server_options) - # embedded data package - for game_name, data in decoded_obj.get("datapackage", {}).items(): - if game_name in game_data_packages: - data = game_data_packages[game_name] - self.logger.info(f"Loading embedded data package for game {game_name}") - self.gamespackage[game_name] = data - self.item_name_groups[game_name] = data["item_name_groups"] - if "location_name_groups" in data: - self.location_name_groups[game_name] = data["location_name_groups"] - del data["location_name_groups"] - del data["item_name_groups"] # remove from data package, but keep in self.item_name_groups + # load and apply world data and (embedded) data package + self._load_world_data() + self._load_data_package(decoded_obj.get("datapackage", {})) self._init_game_data() + for game_name, data in self.item_name_groups.items(): self.read_data[f"item_name_groups_{game_name}"] = lambda lgame=game_name: self.item_name_groups[lgame] for game_name, data in self.location_name_groups.items(): @@ -577,6 +550,55 @@ class Context: # sorted access spheres self.spheres = decoded_obj.get("spheres", []) + def _load_world_data(self) -> None: + import worlds + + for world_name, world in worlds.AutoWorldRegister.world_types.items(): + # TODO: move hint_blacklist into GamesPackage? + self.non_hintable_names[world_name] = world.hint_blacklist + + def _load_data_package(self, data_package: dict[str, GamesPackage]) -> None: + """Populates reduced_games_package, item_name_groups, location_name_groups from static data and data_package""" + # NOTE: for worlds loaded from db, only checksum is set in GamesPackage, but this is handled by cache + for game_name in sorted(self.played_games): + if game_name in data_package: + self.logger.info(f"Loading embedded data package for game {game_name}") + data = self.games_package_cache.get(game_name, data_package[game_name]) + else: + # NOTE: we still allow uploading a game without datapackage. Once that is changed, we could drop this. + data = self.games_package_cache.get_static(game_name) + ( + self.reduced_games_package[game_name], + self.item_name_groups[game_name], + self.location_name_groups[game_name], + ) = data + + del self.games_package_cache # Not used past this point. Free memory. + + def _init_game_data(self) -> None: + """Update internal values from previously loaded data packages""" + for game_name, game_package in self.reduced_games_package.items(): + if game_name not in self.played_games: + continue + if "checksum" in game_package: + self.checksums[game_name] = game_package["checksum"] + # NOTE: we could save more memory by moving the stuff below to data package cache as well + for item_name, item_id in game_package["item_name_to_id"].items(): + self.item_names[game_name][item_id] = item_name + for location_name, location_id in game_package["location_name_to_id"].items(): + self.location_names[game_name][location_id] = location_name + self.all_item_and_group_names[game_name] = \ + set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name]) + self.all_location_and_group_names[game_name] = \ + set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, [])) + + archipelago_item_names = self.item_names["Archipelago"] + archipelago_location_names = self.location_names["Archipelago"] + for game in [game_name for game_name in self.reduced_games_package if game_name != "Archipelago"]: + # Add Archipelago items and locations to each data package. + self.item_names[game].update(archipelago_item_names) + self.location_names[game].update(archipelago_location_names) + # saving def save(self, now=False) -> bool: @@ -917,12 +939,10 @@ async def server(websocket: "ServerConnection", path: str = "/", ctx: Context = async def on_client_connected(ctx: Context, client: Client): - games = {ctx.games[x] for x in range(1, len(ctx.games) + 1)} - games.add("Archipelago") await ctx.send_msgs(client, [{ 'cmd': 'RoomInfo', 'password': bool(ctx.password), - 'games': games, + 'games': sorted(ctx.played_games), # tags are for additional features in the communication. # Name them by feature or fork, as you feel is appropriate. 'tags': ctx.tags, @@ -931,8 +951,7 @@ async def on_client_connected(ctx: Context, client: Client): 'permissions': get_permissions(ctx), 'hint_cost': ctx.hint_cost, 'location_check_points': ctx.location_check_points, - 'datapackage_checksums': {game: game_data["checksum"] for game, game_data - in ctx.gamespackage.items() if game in games and "checksum" in game_data}, + 'datapackage_checksums': ctx.checksums, 'seed_name': ctx.seed_name, 'time': time.time(), }]) @@ -1931,25 +1950,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): await ctx.send_msgs(client, reply) elif cmd == "GetDataPackage": - exclusions = args.get("exclusions", []) - if "games" in args: - games = {name: game_data for name, game_data in ctx.gamespackage.items() - if name in set(args.get("games", []))} - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": {"games": games}}]) - # TODO: remove exclusions behaviour around 0.5.0 - elif exclusions: - exclusions = set(exclusions) - games = {name: game_data for name, game_data in ctx.gamespackage.items() - if name not in exclusions} - - package = {"games": games} - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": package}]) - - else: - await ctx.send_msgs(client, [{"cmd": "DataPackage", - "data": {"games": ctx.gamespackage}}]) + games = { + name: game_data for name, game_data in ctx.reduced_games_package.items() + if name in set(args.get("games", [])) + } + await ctx.send_msgs(client, [{"cmd": "DataPackage", "data": {"games": games}}]) elif client.auth: if cmd == "ConnectUpdate": diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index e353cf2ab2..060afafd37 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -13,6 +13,7 @@ import threading import time import typing import sys +from asyncio import AbstractEventLoop import websockets from pony.orm import commit, db_session, select @@ -24,8 +25,10 @@ from MultiServer import ( server_per_message_deflate_factory, ) from Utils import restricted_loads, cache_argsless +from NetUtils import GamesPackage +from apmw.webhost.customserver.gamespackage.cache import DBGamesPackageCache from .locker import Locker -from .models import Command, GameDataPackage, Room, db +from .models import Command, Room, db class CustomClientMessageProcessor(ClientMessageProcessor): @@ -62,18 +65,39 @@ class DBCommandProcessor(ServerCommandProcessor): class WebHostContext(Context): room_id: int + video: dict[tuple[int, int], tuple[str, str]] + main_loop: AbstractEventLoop + static_server_data: StaticServerData - def __init__(self, static_server_data: dict, logger: logging.Logger): + def __init__( + self, + static_server_data: StaticServerData, + games_package_cache: DBGamesPackageCache, + logger: logging.Logger, + ) -> None: # static server data is used during _load_game_data to load required data, # without needing to import worlds system, which takes quite a bit of memory - self.static_server_data = static_server_data - super(WebHostContext, self).__init__("", 0, "", "", 1, - 40, True, "enabled", "enabled", - "enabled", 0, 2, logger=logger) - del self.static_server_data - self.main_loop = asyncio.get_running_loop() - self.video = {} + super(WebHostContext, self).__init__( + "", + 0, + "", + "", + 1, + 40, + True, + "enabled", + "enabled", + "enabled", + 0, + 2, + games_package_cache=games_package_cache, + logger=logger, + ) self.tags = ["AP", "WebHost"] + self.video = {} + self.main_loop = asyncio.get_running_loop() + self.static_server_data = static_server_data + self.games_package_cache = games_package_cache def __del__(self): try: @@ -83,12 +107,6 @@ class WebHostContext(Context): except ImportError: self.logger.debug("Context destroyed") - def _load_game_data(self): - for key, value in self.static_server_data.items(): - # NOTE: attributes are mutable and shared, so they will have to be copied before being modified - setattr(self, key, value) - self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names) - async def listen_to_db_commands(self): cmdprocessor = DBCommandProcessor(self) @@ -118,42 +136,14 @@ class WebHostContext(Context): self.port = get_random_port() multidata = self.decompress(room.seed.multidata) - game_data_packages = {} + return self._load(multidata, True) - static_gamespackage = self.gamespackage # this is shared across all rooms - static_item_name_groups = self.item_name_groups - static_location_name_groups = self.location_name_groups - self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})} # this may be modified by _load - self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})} - self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})} - missing_checksum = False - - for game in list(multidata.get("datapackage", {})): - game_data = multidata["datapackage"][game] - if "checksum" in game_data: - if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]: - # non-custom. remove from multidata and use static data - # games package could be dropped from static data once all rooms embed data package - del multidata["datapackage"][game] - else: - row = GameDataPackage.get(checksum=game_data["checksum"]) - if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete - game_data_packages[game] = restricted_loads(row.data) - continue - else: - self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}") - else: - missing_checksum = True # Game rolled on old AP and will load data package from multidata - self.gamespackage[game] = static_gamespackage.get(game, {}) - self.item_name_groups[game] = static_item_name_groups.get(game, {}) - self.location_name_groups[game] = static_location_name_groups.get(game, {}) - - if not game_data_packages and not missing_checksum: - # all static -> use the static dicts directly - self.gamespackage = static_gamespackage - self.item_name_groups = static_item_name_groups - self.location_name_groups = static_location_name_groups - return self._load(multidata, game_data_packages, True) + def _load_world_data(self): + # Use static_server_data, but skip static data package since that is in cache anyway. + # Also NOT importing worlds here! + # FIXME: does this copy the non_hintable_names (also for games not part of the room)? + self.non_hintable_names = collections.defaultdict(frozenset, self.static_server_data["non_hintable_names"]) + del self.static_server_data # Not used past this point. Free memory. def init_save(self, enabled: bool = True): self.saving = enabled @@ -185,34 +175,23 @@ def get_random_port(): return random.randint(49152, 65535) +class StaticServerData(typing.TypedDict, total=True): + non_hintable_names: dict[str, typing.AbstractSet[str]] + games_package: dict[str, GamesPackage] + + @cache_argsless -def get_static_server_data() -> dict: +def get_static_server_data() -> StaticServerData: import worlds - data = { + + return { "non_hintable_names": { world_name: world.hint_blacklist for world_name, world in worlds.AutoWorldRegister.world_types.items() }, - "gamespackage": { - world_name: { - key: value - for key, value in game_package.items() - if key not in ("item_name_groups", "location_name_groups") - } - for world_name, game_package in worlds.network_data_package["games"].items() - }, - "item_name_groups": { - world_name: world.item_name_groups - for world_name, world in worlds.AutoWorldRegister.world_types.items() - }, - "location_name_groups": { - world_name: world.location_name_groups - for world_name, world in worlds.AutoWorldRegister.world_types.items() - }, + "games_package": worlds.network_data_package["games"] } - return data - def set_up_logging(room_id) -> logging.Logger: import os @@ -245,9 +224,18 @@ def tear_down_logging(room_id): del logging.Logger.manager.loggerDict[logger_name] -def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, - cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], - host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue): +def run_server_process( + name: str, + ponyconfig: dict[str, typing.Any], + static_server_data: StaticServerData, + cert_file: typing.Optional[str], + cert_key_file: typing.Optional[str], + host: str, + rooms_to_run: multiprocessing.Queue, + rooms_shutting_down: multiprocessing.Queue +) -> None: + import gc + from setproctitle import setproctitle setproctitle(name) @@ -263,6 +251,9 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit)) del resource, file_limit + # prime the data package cache with static data + games_package_cache = DBGamesPackageCache(static_server_data["games_package"]) + # establish DB connection for multidata and multisave db.bind(**ponyconfig) db.generate_mapping(check_tables=False) @@ -270,8 +261,6 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, if "worlds" in sys.modules: raise Exception("Worlds system should not be loaded in the custom server.") - import gc - if not cert_file: def get_ssl_context(): return None @@ -296,7 +285,7 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, with Locker(f"RoomLocker {room_id}"): try: logger = set_up_logging(room_id) - ctx = WebHostContext(static_server_data, logger) + ctx = WebHostContext(static_server_data, games_package_cache, logger) ctx.load(room_id) ctx.init_save() assert ctx.server is None diff --git a/apmw/__init__.py b/apmw/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/multiserver/__init__.py b/apmw/multiserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/multiserver/gamespackage/__init__.py b/apmw/multiserver/gamespackage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/multiserver/gamespackage/cache.py b/apmw/multiserver/gamespackage/cache.py new file mode 100644 index 0000000000..1a4dff9f81 --- /dev/null +++ b/apmw/multiserver/gamespackage/cache.py @@ -0,0 +1,110 @@ +import typing as t +from weakref import WeakValueDictionary + +from NetUtils import GamesPackage, DataPackage + +GameAndChecksum = tuple[str, str | None] +ItemNameGroups = dict[str, list[str]] +LocationNameGroups = dict[str, list[str]] + + +K = t.TypeVar("K") +V = t.TypeVar("V") + + +class DictLike(dict[K, V]): + __slots__ = ("__weakref__",) + + +class GamesPackageCache: + # NOTE: this uses 3 separate collections because unpacking the get() result would end the container lifetime + _reduced_games_packages: WeakValueDictionary[GameAndChecksum, GamesPackage] + """Does not include item_name_groups nor location_name_groups""" + _item_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]] + _location_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]] + + def __init__(self) -> None: + self._reduced_games_packages = WeakValueDictionary() + self._item_name_groups = WeakValueDictionary() + self._location_name_groups = WeakValueDictionary() + + def _cached_item_name(self, key: GameAndChecksum, item_name: str) -> str: + """Returns a reference to an already-stored copy of item_name, or item_name""" + # TODO: there gotta be a better way, but maybe only in a C module? + for cached_item_name in self._reduced_games_packages[key].keys(): + if cached_item_name == item_name: + return cached_item_name + return item_name + + def _cached_location_name(self, key: GameAndChecksum, location_name: str) -> str: + """Returns a reference to an already-stored copy of location_name, or location_name""" + # TODO: as above + for cached_item_name in self._reduced_games_packages[key].keys(): + if cached_item_name == location_name: + return cached_item_name + return location_name + + def _get( + self, + cache_key: GameAndChecksum, + ) -> tuple[GamesPackage | None, ItemNameGroups | None, LocationNameGroups | None]: + if cache_key[1] is None: + return None, None, None + return ( + self._reduced_games_packages.get(cache_key, None), + self._item_name_groups.get(cache_key, None), + self._location_name_groups.get(cache_key, None), + ) + + def get( + self, + game: str, + full_games_package: GamesPackage, + ) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + """Loads and caches embedded data package provided by multidata""" + cache_key = (game, full_games_package.get("checksum", None)) + cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups = self._get(cache_key) + + if cached_reduced_games_package is None: + cached_reduced_games_package = t.cast( + t.Any, + DictLike( + { + "item_name_to_id": full_games_package["item_name_to_id"], + "location_name_to_id": full_games_package["location_name_to_id"], + "checksum": full_games_package["checksum"], + } + ), + ) + if cache_key[1] is not None: # only cache if checksum is available + self._reduced_games_packages[cache_key] = cached_reduced_games_package + + if cached_item_name_groups is None: + cached_item_name_groups = DictLike( + { + group_name: [self._cached_item_name(cache_key, item_name) for item_name in group_items] + for group_name, group_items in full_games_package["item_name_groups"].items() + } + ) + if cache_key[1] is not None: # only cache if checksum is available + self._item_name_groups[cache_key] = cached_item_name_groups + + if cached_location_name_groups is None: + cached_location_name_groups = DictLike( + { + group_name: [ + self._cached_location_name(cache_key, location_name) for location_name in group_locations + ] + for group_name, group_locations in full_games_package.get("location_name_groups", {}).items() + } + ) + if cache_key[1] is not None: # only cache if checksum is available + self._location_name_groups[cache_key] = cached_location_name_groups + + return cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups + + def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + """Loads legacy data package from installed worlds""" + import worlds + + return self.get(game, worlds.network_data_package["games"][game]) diff --git a/apmw/webhost/__init__.py b/apmw/webhost/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/webhost/customserver/__init__.py b/apmw/webhost/customserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/webhost/customserver/gamespackage/__init__.py b/apmw/webhost/customserver/gamespackage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apmw/webhost/customserver/gamespackage/cache.py b/apmw/webhost/customserver/gamespackage/cache.py new file mode 100644 index 0000000000..8e3f7c83e0 --- /dev/null +++ b/apmw/webhost/customserver/gamespackage/cache.py @@ -0,0 +1,32 @@ +import typing as t + +from NetUtils import GamesPackage +from Utils import restricted_loads +from WebHostLib.models import GameDataPackage +from apmw.multiserver.gamespackage.cache import GamesPackageCache, ItemNameGroups, LocationNameGroups + + +class DBGamesPackageCache(GamesPackageCache): + _static: dict[str, tuple[GamesPackage, ItemNameGroups, LocationNameGroups]] + + def __init__(self, static_games_package: dict[str, GamesPackage]) -> None: + super().__init__() + self._static = {game: super().get(game, games_package) for game, games_package in static_games_package.items()} + + def get( + self, + game: str, + full_games_package: GamesPackage, + ) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + # for games started on webhost, full_games_package is likely unpopulated and only has the checksum field + cache_key = (game, full_games_package.get("checksum", None)) + cached = self._get(cache_key) + if any(value is None for value in cached): + row = GameDataPackage.get(checksum=full_games_package["checksum"]) + if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8 ... + return super().get(game, restricted_loads(row.data)) + return super().get(game, full_games_package) # ... in which case full_games_package should be populated + return t.cast(tuple[GamesPackage, ItemNameGroups, LocationNameGroups], cached) + + def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]: + return self._static[game]