Merge branch 'feat/data-package-cache' into active/rc-site

This commit is contained in:
black-sliver
2026-03-10 22:00:09 +01:00
12 changed files with 576 additions and 165 deletions

View File

@@ -44,8 +44,9 @@ import NetUtils
import Utils
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType, LocationStore, MultiData, Hint, HintStatus
SlotType, LocationStore, MultiData, Hint, HintStatus, GamesPackage
from BaseClasses import ItemClassification
from apmw.multiserver.gamespackagecache import GamesPackageCache
min_client_version = Version(0, 5, 0)
@@ -241,21 +242,38 @@ class Context:
slot_info: typing.Dict[int, NetworkSlot]
generator_version = Version(0, 0, 0)
checksums: typing.Dict[str, str]
played_games: set[str]
item_names: typing.Dict[str, typing.Dict[int, str]]
item_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]]
item_name_groups: typing.Dict[str, typing.Dict[str, list[str]]]
location_names: typing.Dict[str, typing.Dict[int, str]]
location_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]]
location_name_groups: typing.Dict[str, typing.Dict[str, list[str]]]
all_item_and_group_names: typing.Dict[str, typing.Set[str]]
all_location_and_group_names: typing.Dict[str, typing.Set[str]]
non_hintable_names: typing.Dict[str, typing.AbstractSet[str]]
spheres: typing.List[typing.Dict[int, typing.Set[int]]]
""" each sphere is { player: { location_id, ... } } """
games_package_cache: GamesPackageCache
logger: logging.Logger
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled",
countdown_mode: str = "auto", remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0,
compatibility: int = 2, log_network: bool = False, logger: logging.Logger = logging.getLogger()):
def __init__(
self,
host: str,
port: int,
server_password: str,
password: str,
location_check_points: int,
hint_cost: int,
item_cheat: bool,
release_mode: str = "disabled",
collect_mode="disabled",
countdown_mode: str = "auto",
remaining_mode: str = "disabled",
auto_shutdown: typing.SupportsFloat = 0,
compatibility: int = 2,
log_network: bool = False,
games_package_cache: GamesPackageCache | None = None,
logger: logging.Logger = logging.getLogger(),
) -> None:
self.logger = logger
super(Context, self).__init__()
self.slot_info = {}
@@ -306,6 +324,7 @@ class Context:
self.save_dirty = False
self.tags = ['AP']
self.games: typing.Dict[int, str] = {}
self.played_games = set()
self.minimum_client_versions: typing.Dict[int, Version] = {}
self.seed_name = ""
self.groups = {}
@@ -315,9 +334,10 @@ class Context:
self.stored_data_notification_clients = collections.defaultdict(weakref.WeakSet)
self.read_data = {}
self.spheres = []
self.games_package_cache = games_package_cache or GamesPackageCache()
# init empty to satisfy linter, I suppose
self.gamespackage = {}
self.reduced_games_package = {}
self.checksums = {}
self.item_name_groups = {}
self.location_name_groups = {}
@@ -329,50 +349,11 @@ class Context:
lambda: Utils.KeyedDefaultDict(lambda code: f'Unknown location (ID:{code})'))
self.non_hintable_names = collections.defaultdict(frozenset)
self._load_game_data()
# Data package retrieval
def _load_game_data(self):
import worlds
self.gamespackage = worlds.network_data_package["games"]
self.item_name_groups = {world_name: world.item_name_groups for world_name, world in
worlds.AutoWorldRegister.world_types.items()}
self.location_name_groups = {world_name: world.location_name_groups for world_name, world in
worlds.AutoWorldRegister.world_types.items()}
for world_name, world in worlds.AutoWorldRegister.world_types.items():
self.non_hintable_names[world_name] = world.hint_blacklist
for game_package in self.gamespackage.values():
# remove groups from data sent to clients
del game_package["item_name_groups"]
del game_package["location_name_groups"]
def _init_game_data(self):
for game_name, game_package in self.gamespackage.items():
if "checksum" in game_package:
self.checksums[game_name] = game_package["checksum"]
for item_name, item_id in game_package["item_name_to_id"].items():
self.item_names[game_name][item_id] = item_name
for location_name, location_id in game_package["location_name_to_id"].items():
self.location_names[game_name][location_id] = location_name
self.all_item_and_group_names[game_name] = \
set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name])
self.all_location_and_group_names[game_name] = \
set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, []))
archipelago_item_names = self.item_names["Archipelago"]
archipelago_location_names = self.location_names["Archipelago"]
for game in [game_name for game_name in self.gamespackage if game_name != "Archipelago"]:
# Add Archipelago items and locations to each data package.
self.item_names[game].update(archipelago_item_names)
self.location_names[game].update(archipelago_location_names)
def item_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]:
return self.gamespackage[game]["item_name_to_id"] if game in self.gamespackage else None
return self.reduced_games_package[game]["item_name_to_id"] if game in self.reduced_games_package else None
def location_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]:
return self.gamespackage[game]["location_name_to_id"] if game in self.gamespackage else None
return self.reduced_games_package[game]["location_name_to_id"] if game in self.reduced_games_package else None
# General networking
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
@@ -482,19 +463,17 @@ class Context:
with open(multidatapath, 'rb') as f:
data = f.read()
self._load(self.decompress(data), {}, use_embedded_server_options)
self._load(self.decompress(data), use_embedded_server_options)
self.data_filename = multidatapath
@staticmethod
def decompress(data: bytes) -> dict:
def decompress(data: bytes) -> typing.Any:
format_version = data[0]
if format_version > 3:
raise Utils.VersionException("Incompatible multidata.")
return restricted_loads(zlib.decompress(data[1:]))
def _load(self, decoded_obj: MultiData, game_data_packages: typing.Dict[str, typing.Any],
use_embedded_server_options: bool):
def _load(self, decoded_obj: MultiData, use_embedded_server_options: bool) -> None:
self.read_data = {}
# there might be a better place to put this.
race_mode = decoded_obj.get("race_mode", 0)
@@ -515,6 +494,7 @@ class Context:
self.slot_info = decoded_obj["slot_info"]
self.games = {slot: slot_info.game for slot, slot_info in self.slot_info.items()}
self.played_games = {"Archipelago"} | {self.games[x] for x in range(1, len(self.games) + 1)}
self.groups = {slot: set(slot_info.group_members) for slot, slot_info in self.slot_info.items()
if slot_info.type == SlotType.group}
@@ -559,18 +539,11 @@ class Context:
server_options = decoded_obj.get("server_options", {})
self._set_options(server_options)
# embedded data package
for game_name, data in decoded_obj.get("datapackage", {}).items():
if game_name in game_data_packages:
data = game_data_packages[game_name]
self.logger.info(f"Loading embedded data package for game {game_name}")
self.gamespackage[game_name] = data
self.item_name_groups[game_name] = data["item_name_groups"]
if "location_name_groups" in data:
self.location_name_groups[game_name] = data["location_name_groups"]
del data["location_name_groups"]
del data["item_name_groups"] # remove from data package, but keep in self.item_name_groups
# load and apply world data and (embedded) data package
self._load_world_data()
self._load_data_package(decoded_obj.get("datapackage", {}))
self._init_game_data()
for game_name, data in self.item_name_groups.items():
self.read_data[f"item_name_groups_{game_name}"] = lambda lgame=game_name: self.item_name_groups[lgame]
for game_name, data in self.location_name_groups.items():
@@ -579,6 +552,55 @@ class Context:
# sorted access spheres
self.spheres = decoded_obj.get("spheres", [])
def _load_world_data(self) -> None:
import worlds
for world_name, world in worlds.AutoWorldRegister.world_types.items():
# TODO: move hint_blacklist into GamesPackage?
self.non_hintable_names[world_name] = world.hint_blacklist
def _load_data_package(self, data_package: dict[str, GamesPackage]) -> None:
"""Populates reduced_games_package, item_name_groups, location_name_groups from static data and data_package"""
# NOTE: for worlds loaded from db, only checksum is set in GamesPackage, but this is handled by cache
for game_name in sorted(self.played_games):
if game_name in data_package:
self.logger.info(f"Loading embedded data package for game {game_name}")
data = self.games_package_cache.get(game_name, data_package[game_name])
else:
# NOTE: we still allow uploading a game without datapackage. Once that is changed, we could drop this.
data = self.games_package_cache.get_static(game_name)
(
self.reduced_games_package[game_name],
self.item_name_groups[game_name],
self.location_name_groups[game_name],
) = data
del self.games_package_cache # Not used past this point. Free memory.
def _init_game_data(self) -> None:
"""Update internal values from previously loaded data packages"""
for game_name, game_package in self.reduced_games_package.items():
if game_name not in self.played_games:
continue
if "checksum" in game_package:
self.checksums[game_name] = game_package["checksum"]
# NOTE: we could save more memory by moving the stuff below to data package cache as well
for item_name, item_id in game_package["item_name_to_id"].items():
self.item_names[game_name][item_id] = item_name
for location_name, location_id in game_package["location_name_to_id"].items():
self.location_names[game_name][location_id] = location_name
self.all_item_and_group_names[game_name] = \
set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name])
self.all_location_and_group_names[game_name] = \
set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, []))
archipelago_item_names = self.item_names["Archipelago"]
archipelago_location_names = self.location_names["Archipelago"]
for game in [game_name for game_name in self.reduced_games_package if game_name != "Archipelago"]:
# Add Archipelago items and locations to each data package.
self.item_names[game].update(archipelago_item_names)
self.location_names[game].update(archipelago_location_names)
# saving
def save(self, now=False) -> bool:
@@ -919,12 +941,10 @@ async def server(websocket: "ServerConnection", path: str = "/", ctx: Context =
async def on_client_connected(ctx: Context, client: Client):
games = {ctx.games[x] for x in range(1, len(ctx.games) + 1)}
games.add("Archipelago")
await ctx.send_msgs(client, [{
'cmd': 'RoomInfo',
'password': bool(ctx.password),
'games': games,
'games': sorted(ctx.played_games),
# tags are for additional features in the communication.
# Name them by feature or fork, as you feel is appropriate.
'tags': ctx.tags,
@@ -933,8 +953,7 @@ async def on_client_connected(ctx: Context, client: Client):
'permissions': get_permissions(ctx),
'hint_cost': ctx.hint_cost,
'location_check_points': ctx.location_check_points,
'datapackage_checksums': {game: game_data["checksum"] for game, game_data
in ctx.gamespackage.items() if game in games and "checksum" in game_data},
'datapackage_checksums': ctx.checksums,
'seed_name': ctx.seed_name,
'time': time.time(),
}])
@@ -1940,25 +1959,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
await ctx.send_msgs(client, reply)
elif cmd == "GetDataPackage":
exclusions = args.get("exclusions", [])
if "games" in args:
games = {name: game_data for name, game_data in ctx.gamespackage.items()
if name in set(args.get("games", []))}
await ctx.send_msgs(client, [{"cmd": "DataPackage",
"data": {"games": games}}])
# TODO: remove exclusions behaviour around 0.5.0
elif exclusions:
exclusions = set(exclusions)
games = {name: game_data for name, game_data in ctx.gamespackage.items()
if name not in exclusions}
package = {"games": games}
await ctx.send_msgs(client, [{"cmd": "DataPackage",
"data": package}])
else:
await ctx.send_msgs(client, [{"cmd": "DataPackage",
"data": {"games": ctx.gamespackage}}])
games = {
name: game_data for name, game_data in ctx.reduced_games_package.items()
if name in set(args.get("games", []))
}
await ctx.send_msgs(client, [{"cmd": "DataPackage", "data": {"games": games}}])
elif client.auth:
if cmd == "ConnectUpdate":

View File

@@ -13,6 +13,7 @@ import threading
import time
import typing
import sys
from asyncio import AbstractEventLoop
import websockets
from pony.orm import commit, db_session, select
@@ -24,8 +25,10 @@ from MultiServer import (
server_per_message_deflate_factory,
)
from Utils import restricted_loads, cache_argsless
from NetUtils import GamesPackage
from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache
from .locker import Locker
from .models import Command, GameDataPackage, Room, db
from .models import Command, Room, db
class CustomClientMessageProcessor(ClientMessageProcessor):
@@ -62,18 +65,39 @@ class DBCommandProcessor(ServerCommandProcessor):
class WebHostContext(Context):
room_id: int
video: dict[tuple[int, int], tuple[str, str]]
main_loop: AbstractEventLoop
static_server_data: StaticServerData
def __init__(self, static_server_data: dict, logger: logging.Logger):
def __init__(
self,
static_server_data: StaticServerData,
games_package_cache: DBGamesPackageCache,
logger: logging.Logger,
) -> None:
# static server data is used during _load_game_data to load required data,
# without needing to import worlds system, which takes quite a bit of memory
self.static_server_data = static_server_data
super(WebHostContext, self).__init__("", 0, "", "", 1,
40, True, "enabled", "enabled",
"enabled", 0, 2, logger=logger)
del self.static_server_data
self.main_loop = asyncio.get_running_loop()
self.video = {}
super(WebHostContext, self).__init__(
"",
0,
"",
"",
1,
40,
True,
"enabled",
"enabled",
"enabled",
0,
2,
games_package_cache=games_package_cache,
logger=logger,
)
self.tags = ["AP", "WebHost"]
self.video = {}
self.main_loop = asyncio.get_running_loop()
self.static_server_data = static_server_data
self.games_package_cache = games_package_cache
def __del__(self):
try:
@@ -83,12 +107,6 @@ class WebHostContext(Context):
except ImportError:
self.logger.debug("Context destroyed")
def _load_game_data(self):
for key, value in self.static_server_data.items():
# NOTE: attributes are mutable and shared, so they will have to be copied before being modified
setattr(self, key, value)
self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names)
async def listen_to_db_commands(self):
cmdprocessor = DBCommandProcessor(self)
@@ -118,42 +136,14 @@ class WebHostContext(Context):
self.port = get_random_port()
multidata = self.decompress(room.seed.multidata)
game_data_packages = {}
return self._load(multidata, True)
static_gamespackage = self.gamespackage # this is shared across all rooms
static_item_name_groups = self.item_name_groups
static_location_name_groups = self.location_name_groups
self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})} # this may be modified by _load
self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})}
self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})}
missing_checksum = False
for game in list(multidata.get("datapackage", {})):
game_data = multidata["datapackage"][game]
if "checksum" in game_data:
if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]:
# non-custom. remove from multidata and use static data
# games package could be dropped from static data once all rooms embed data package
del multidata["datapackage"][game]
else:
row = GameDataPackage.get(checksum=game_data["checksum"])
if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete
game_data_packages[game] = restricted_loads(row.data)
continue
else:
self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}")
else:
missing_checksum = True # Game rolled on old AP and will load data package from multidata
self.gamespackage[game] = static_gamespackage.get(game, {})
self.item_name_groups[game] = static_item_name_groups.get(game, {})
self.location_name_groups[game] = static_location_name_groups.get(game, {})
if not game_data_packages and not missing_checksum:
# all static -> use the static dicts directly
self.gamespackage = static_gamespackage
self.item_name_groups = static_item_name_groups
self.location_name_groups = static_location_name_groups
return self._load(multidata, game_data_packages, True)
def _load_world_data(self):
# Use static_server_data, but skip static data package since that is in cache anyway.
# Also NOT importing worlds here!
# FIXME: does this copy the non_hintable_names (also for games not part of the room)?
self.non_hintable_names = collections.defaultdict(frozenset, self.static_server_data["non_hintable_names"])
del self.static_server_data # Not used past this point. Free memory.
def init_save(self, enabled: bool = True):
self.saving = enabled
@@ -185,34 +175,23 @@ def get_random_port():
return random.randint(49152, 65535)
class StaticServerData(typing.TypedDict, total=True):
non_hintable_names: dict[str, typing.AbstractSet[str]]
games_package: dict[str, GamesPackage]
@cache_argsless
def get_static_server_data() -> dict:
def get_static_server_data() -> StaticServerData:
import worlds
data = {
return {
"non_hintable_names": {
world_name: world.hint_blacklist
for world_name, world in worlds.AutoWorldRegister.world_types.items()
},
"gamespackage": {
world_name: {
key: value
for key, value in game_package.items()
if key not in ("item_name_groups", "location_name_groups")
}
for world_name, game_package in worlds.network_data_package["games"].items()
},
"item_name_groups": {
world_name: world.item_name_groups
for world_name, world in worlds.AutoWorldRegister.world_types.items()
},
"location_name_groups": {
world_name: world.location_name_groups
for world_name, world in worlds.AutoWorldRegister.world_types.items()
},
"games_package": worlds.network_data_package["games"]
}
return data
def set_up_logging(room_id) -> logging.Logger:
import os
@@ -245,9 +224,18 @@ def tear_down_logging(room_id):
del logging.Logger.manager.loggerDict[logger_name]
def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
cert_file: typing.Optional[str], cert_key_file: typing.Optional[str],
host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue):
def run_server_process(
name: str,
ponyconfig: dict[str, typing.Any],
static_server_data: StaticServerData,
cert_file: typing.Optional[str],
cert_key_file: typing.Optional[str],
host: str,
rooms_to_run: multiprocessing.Queue,
rooms_shutting_down: multiprocessing.Queue,
) -> None:
import gc
from setproctitle import setproctitle
setproctitle(name)
@@ -263,6 +251,9 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
del resource, file_limit
# prime the data package cache with static data
games_package_cache = DBGamesPackageCache(static_server_data["games_package"])
# establish DB connection for multidata and multisave
db.bind(**ponyconfig)
db.generate_mapping(check_tables=False)
@@ -270,8 +261,6 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
if "worlds" in sys.modules:
raise Exception("Worlds system should not be loaded in the custom server.")
import gc
if not cert_file:
def get_ssl_context():
return None
@@ -296,7 +285,7 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
with Locker(f"RoomLocker {room_id}"):
try:
logger = set_up_logging(room_id)
ctx = WebHostContext(static_server_data, logger)
ctx = WebHostContext(static_server_data, games_package_cache, logger)
ctx.load(room_id)
ctx.init_save()
assert ctx.server is None

0
apmw/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,96 @@
import typing as t
from weakref import WeakValueDictionary
from NetUtils import GamesPackage
GameAndChecksum = tuple[str, str | None]
ItemNameGroups = dict[str, list[str]]
LocationNameGroups = dict[str, list[str]]
K = t.TypeVar("K")
V = t.TypeVar("V")
class DictLike(dict[K, V]):
__slots__ = ("__weakref__",)
class GamesPackageCache:
# NOTE: this uses 3 separate collections because unpacking the get() result would end the container lifetime
_reduced_games_packages: WeakValueDictionary[GameAndChecksum, GamesPackage]
"""Does not include item_name_groups nor location_name_groups"""
_item_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]]
_location_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]]
def __init__(self) -> None:
self._reduced_games_packages = WeakValueDictionary()
self._item_name_groups = WeakValueDictionary()
self._location_name_groups = WeakValueDictionary()
def _get(
self,
cache_key: GameAndChecksum,
) -> tuple[GamesPackage | None, ItemNameGroups | None, LocationNameGroups | None]:
if cache_key[1] is None:
return None, None, None
return (
self._reduced_games_packages.get(cache_key, None),
self._item_name_groups.get(cache_key, None),
self._location_name_groups.get(cache_key, None),
)
def get(
self,
game: str,
full_games_package: GamesPackage,
) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
"""Loads and caches embedded data package provided by multidata"""
cache_key = (game, full_games_package.get("checksum", None))
cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups = self._get(cache_key)
if cached_reduced_games_package is None:
cached_reduced_games_package = t.cast(
t.Any,
DictLike(
{
"item_name_to_id": full_games_package["item_name_to_id"],
"location_name_to_id": full_games_package["location_name_to_id"],
"checksum": full_games_package.get("checksum", None),
}
),
)
if cache_key[1] is not None: # only cache if checksum is available
self._reduced_games_packages[cache_key] = cached_reduced_games_package
if cached_item_name_groups is None:
# optimize strings to be references instead of copies
item_names = {name: name for name in cached_reduced_games_package["item_name_to_id"].keys()}
cached_item_name_groups = DictLike(
{
group_name: [item_names.get(item_name, item_name) for item_name in group_items]
for group_name, group_items in full_games_package["item_name_groups"].items()
}
)
if cache_key[1] is not None: # only cache if checksum is available
self._item_name_groups[cache_key] = cached_item_name_groups
if cached_location_name_groups is None:
# optimize strings to be references instead of copies
location_names = {name: name for name in cached_reduced_games_package["location_name_to_id"].keys()}
cached_location_name_groups = DictLike(
{
group_name: [location_names.get(location_name, location_name) for location_name in group_locations]
for group_name, group_locations in full_games_package.get("location_name_groups", {}).items()
}
)
if cache_key[1] is not None: # only cache if checksum is available
self._location_name_groups[cache_key] = cached_location_name_groups
return cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups
def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
"""Loads legacy data package from installed worlds"""
import worlds
return self.get(game, worlds.network_data_package["games"][game])

0
apmw/webhost/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,42 @@
from typing_extensions import override
from NetUtils import GamesPackage
from Utils import restricted_loads
from apmw.multiserver.gamespackagecache import GamesPackageCache, ItemNameGroups, LocationNameGroups
class DBGamesPackageCache(GamesPackageCache):
_static: dict[str, tuple[GamesPackage, ItemNameGroups, LocationNameGroups]]
def __init__(self, static_games_package: dict[str, GamesPackage]) -> None:
super().__init__()
self._static = {
game: GamesPackageCache.get(self, game, games_package)
for game, games_package in static_games_package.items()
}
@override
def get(
self,
game: str,
full_games_package: GamesPackage,
) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
# for games started on webhost, full_games_package is likely unpopulated and only has the checksum field
cache_key = (game, full_games_package.get("checksum", None))
cached = self._get(cache_key)
if any(value is None for value in cached):
if "checksum" not in full_games_package:
return super().get(game, full_games_package) # no checksum, assume fully populated
from WebHostLib.models import GameDataPackage
row: GameDataPackage | None = GameDataPackage.get(checksum=full_games_package["checksum"])
if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8 ...
return super().get(game, restricted_loads(row.data))
return super().get(game, full_games_package) # ... in which case full_games_package should be populated
return cached # type: ignore # mypy doesn't understand any value is None
@override
def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
return self._static[game]

View File

View File

@@ -0,0 +1,132 @@
import typing as t
from copy import deepcopy
from unittest import TestCase
from typing_extensions import override
import NetUtils
from NetUtils import GamesPackage
from apmw.multiserver.gamespackagecache import GamesPackageCache
class GamesPackageCacheTest(TestCase):
cache: GamesPackageCache
any_game: t.ClassVar[str] = "APQuest"
example_games_package: GamesPackage = {
"item_name_to_id": {"Item 1": 1},
"item_name_groups": {"Everything": ["Item 1"]},
"location_name_to_id": {"Location 1": 1},
"location_name_groups": {"Everywhere": ["Location 1"]},
"checksum": "1234",
}
@override
def setUp(self) -> None:
self.cache = GamesPackageCache()
def test_get_static_is_same(self) -> None:
"""Tests that get_static returns the same objects twice"""
reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get_static(self.any_game)
reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game)
self.assertIs(reduced_games_package1, reduced_games_package2)
self.assertIs(item_name_groups1, item_name_groups2)
self.assertIs(location_name_groups1, location_name_groups2)
def test_get_static_data_format(self) -> None:
"""Tests that get_static returns data in the correct format"""
reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game)
self.assertTrue(reduced_games_package["checksum"])
self.assertTrue(reduced_games_package["item_name_to_id"])
self.assertTrue(reduced_games_package["location_name_to_id"])
self.assertNotIn("item_name_groups", reduced_games_package)
self.assertNotIn("location_name_groups", reduced_games_package)
self.assertTrue(item_name_groups["Everything"])
self.assertTrue(location_name_groups["Everywhere"])
def test_get_static_is_serializable(self) -> None:
"""Tests that get_static returns data that can be serialized"""
NetUtils.encode(self.cache.get_static(self.any_game))
def test_get_static_missing_raises(self) -> None:
"""Tests that get_static raises KeyError if the world is missing"""
with self.assertRaises(KeyError):
_ = self.cache.get_static("Does not exist")
def test_eviction(self) -> None:
"""Tests that unused items get evicted from cache"""
game_name = "Test"
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, self.example_games_package)
self.assertTrue(data)
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
del data
if len(self.cache._reduced_games_packages) != before_add: # gc.collect() may not even be required
import gc
gc.collect()
self.assertEqual(before_add, len(self.cache._reduced_games_packages))
def test_get_required_field(self) -> None:
"""Tests that missing required field raises a KeyError"""
for field in ("item_name_to_id", "location_name_to_id", "item_name_groups"):
with self.subTest(field=field):
games_package = deepcopy(self.example_games_package)
del games_package[field] # type: ignore
with self.assertRaises(KeyError):
_ = self.cache.get(self.any_game, games_package)
def test_get_optional_properties(self) -> None:
"""Tests that missing optional field works"""
for field in ("checksum", "location_name_groups"):
with self.subTest(field=field):
games_package = deepcopy(self.example_games_package)
del games_package[field] # type: ignore
_, item_name_groups, location_name_groups = self.cache.get(self.any_game, games_package)
self.assertTrue(item_name_groups)
self.assertEqual(field != "location_name_groups", bool(location_name_groups))
def test_item_name_deduplication(self) -> None:
n = 1
s1 = f"Item {n}"
s2 = f"Item {n}"
# check if the deduplication is actually gonna do anything
self.assertIsNot(s1, s2)
self.assertEqual(s1, s2)
# do the thing
game_name = "Test"
games_package: GamesPackage = {
"item_name_to_id": {s1: n},
"item_name_groups": {"Everything": [s2]},
"location_name_to_id": {},
"location_name_groups": {},
"checksum": "1234",
}
reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package)
self.assertIs(
next(iter(reduced_games_package["item_name_to_id"].keys())),
item_name_groups["Everything"][0],
)
def test_location_name_deduplication(self) -> None:
n = 1
s1 = f"Location {n}"
s2 = f"Location {n}"
# check if the deduplication is actually gonna do anything
self.assertIsNot(s1, s2)
self.assertEqual(s1, s2)
# do the thing
game_name = "Test"
games_package: GamesPackage = {
"item_name_to_id": {},
"item_name_groups": {},
"location_name_to_id": {s1: n},
"location_name_groups": {"Everywhere": [s2]},
"checksum": "1234",
}
reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package)
self.assertIs(
next(iter(reduced_games_package["location_name_to_id"].keys())),
location_name_groups["Everywhere"][0],
)

View File

View File

@@ -0,0 +1,147 @@
import typing as t
from copy import deepcopy
from typing_extensions import override
from test.multiserver.test_gamespackage_cache import GamesPackageCacheTest
import Utils
from NetUtils import GamesPackage
from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache
class FakeGameDataPackage:
_rows: "t.ClassVar[dict[str, FakeGameDataPackage]]" = {}
data: bytes
@classmethod
def get(cls, checksum: str) -> "FakeGameDataPackage | None":
return cls._rows.get(checksum, None)
@classmethod
def add(cls, checksum: str, full_games_package: GamesPackage) -> None:
row = FakeGameDataPackage()
row.data = Utils.restricted_dumps(full_games_package)
cls._rows[checksum] = row
class DBGamesPackageCacheTest(GamesPackageCacheTest):
cache: DBGamesPackageCache
any_game: t.ClassVar[str] = "My Game"
static_data: t.ClassVar[dict[str, GamesPackage]] = { # noqa: pycharm doesn't understand this
"My Game": {
"item_name_to_id": {"Item 1": 1},
"location_name_to_id": {"Location 1": 1},
"item_name_groups": {"Everything": ["Item 1"]},
"location_name_groups": {"Everywhere": ["Location 1"]},
"checksum": "2345",
}
}
orig_db_type: t.ClassVar[type]
@override
@classmethod
def setUpClass(cls) -> None:
import WebHostLib.models
cls.orig_db_type = WebHostLib.models.GameDataPackage
WebHostLib.models.GameDataPackage = FakeGameDataPackage # type: ignore
@override
def setUp(self) -> None:
self.cache = DBGamesPackageCache(self.static_data)
@override
@classmethod
def tearDownClass(cls) -> None:
import WebHostLib.models
WebHostLib.models.GameDataPackage = cls.orig_db_type # type: ignore
def assert_conversion(
self,
full_games_package: GamesPackage,
reduced_games_package: dict[str, t.Any],
item_name_groups: dict[str, t.Any],
location_name_groups: dict[str, t.Any],
) -> None:
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
if key in full_games_package:
self.assertEqual(reduced_games_package[key], full_games_package[key]) # noqa: pycharm
self.assertEqual(item_name_groups, full_games_package["item_name_groups"])
self.assertEqual(location_name_groups, full_games_package["location_name_groups"])
def assert_static_conversion(
self,
full_games_package: GamesPackage,
reduced_games_package: dict[str, t.Any],
item_name_groups: dict[str, t.Any],
location_name_groups: dict[str, t.Any],
) -> None:
self.assert_conversion(full_games_package, reduced_games_package, item_name_groups, location_name_groups)
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
self.assertIs(reduced_games_package[key], full_games_package[key]) # noqa: pycharm
def test_get_static_contents(self) -> None:
"""Tests that get_static returns the correct data"""
reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game)
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
self.assertIs(reduced_games_package[key], self.static_data[self.any_game][key]) # noqa: pycharm
self.assertEqual(item_name_groups, self.static_data[self.any_game]["item_name_groups"])
self.assertEqual(location_name_groups, self.static_data[self.any_game]["location_name_groups"])
def test_static_not_evicted(self) -> None:
"""Tests that static data is not evicted from cache during gc"""
import gc
game_name = next(iter(self.static_data.keys()))
ids = [id(o) for o in self.cache.get_static(game_name)]
gc.collect()
self.assertEqual(ids, [id(o) for o in self.cache.get_static(game_name)])
def test_get_is_static(self) -> None:
"""Tests that a get with correct checksum return the static items"""
# NOTE: this is only true for the DB cache, not the "regular" one, since we want to avoid loading worlds there
cks: GamesPackage = {"checksum": self.static_data[self.any_game]["checksum"]} # noqa: pycharm doesn't like this
reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get(self.any_game, cks)
reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game)
self.assertIs(reduced_games_package1, reduced_games_package2)
self.assertEqual(location_name_groups1, location_name_groups2)
self.assertEqual(item_name_groups1, item_name_groups2)
def test_get_from_db(self) -> None:
"""Tests that a get with only checksum will load the full data from db and is cached"""
game_name = "Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
full_games_package["checksum"] = "3456"
cks: GamesPackage = {"checksum": full_games_package["checksum"]} # noqa: pycharm doesn't like this
FakeGameDataPackage.add(full_games_package["checksum"], full_games_package)
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, cks)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
def test_get_missing_from_db_uses_full_games_package(self) -> None:
"""Tests that a get with full data (missing from db) will use the full data and is cached"""
game_name = "Yet Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
full_games_package["checksum"] = "4567"
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, full_games_package)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
def test_get_without_checksum_uses_full_games_package(self) -> None:
"""Tests that a get with full data and no checksum will use the full data and is not cached"""
game_name = "Yet Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
del full_games_package["checksum"]
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, full_games_package)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add, len(self.cache._reduced_games_packages))
def test_get_missing_from_db_raises(self) -> None:
"""Tests that a get that requires a row to exist raise an exception if it doesn't"""
with self.assertRaises(Exception):
_ = self.cache.get("Does not exist", {"checksum": "0000"})