Merge branch 'feat/data-package-cache' into active/rc-site

This commit is contained in:
black-sliver
2026-03-10 22:00:09 +01:00
12 changed files with 576 additions and 165 deletions
+94 -89
View File
@@ -44,8 +44,9 @@ import NetUtils
import Utils import Utils
from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text from Utils import version_tuple, restricted_loads, Version, async_start, get_intended_text
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType, LocationStore, MultiData, Hint, HintStatus SlotType, LocationStore, MultiData, Hint, HintStatus, GamesPackage
from BaseClasses import ItemClassification from BaseClasses import ItemClassification
from apmw.multiserver.gamespackagecache import GamesPackageCache
min_client_version = Version(0, 5, 0) min_client_version = Version(0, 5, 0)
@@ -241,21 +242,38 @@ class Context:
slot_info: typing.Dict[int, NetworkSlot] slot_info: typing.Dict[int, NetworkSlot]
generator_version = Version(0, 0, 0) generator_version = Version(0, 0, 0)
checksums: typing.Dict[str, str] checksums: typing.Dict[str, str]
played_games: set[str]
item_names: typing.Dict[str, typing.Dict[int, str]] item_names: typing.Dict[str, typing.Dict[int, str]]
item_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] item_name_groups: typing.Dict[str, typing.Dict[str, list[str]]]
location_names: typing.Dict[str, typing.Dict[int, str]] location_names: typing.Dict[str, typing.Dict[int, str]]
location_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] location_name_groups: typing.Dict[str, typing.Dict[str, list[str]]]
all_item_and_group_names: typing.Dict[str, typing.Set[str]] all_item_and_group_names: typing.Dict[str, typing.Set[str]]
all_location_and_group_names: typing.Dict[str, typing.Set[str]] all_location_and_group_names: typing.Dict[str, typing.Set[str]]
non_hintable_names: typing.Dict[str, typing.AbstractSet[str]] non_hintable_names: typing.Dict[str, typing.AbstractSet[str]]
spheres: typing.List[typing.Dict[int, typing.Set[int]]] spheres: typing.List[typing.Dict[int, typing.Set[int]]]
""" each sphere is { player: { location_id, ... } } """ """ each sphere is { player: { location_id, ... } } """
games_package_cache: GamesPackageCache
logger: logging.Logger logger: logging.Logger
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, def __init__(
hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled", self,
countdown_mode: str = "auto", remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0, host: str,
compatibility: int = 2, log_network: bool = False, logger: logging.Logger = logging.getLogger()): port: int,
server_password: str,
password: str,
location_check_points: int,
hint_cost: int,
item_cheat: bool,
release_mode: str = "disabled",
collect_mode="disabled",
countdown_mode: str = "auto",
remaining_mode: str = "disabled",
auto_shutdown: typing.SupportsFloat = 0,
compatibility: int = 2,
log_network: bool = False,
games_package_cache: GamesPackageCache | None = None,
logger: logging.Logger = logging.getLogger(),
) -> None:
self.logger = logger self.logger = logger
super(Context, self).__init__() super(Context, self).__init__()
self.slot_info = {} self.slot_info = {}
@@ -306,6 +324,7 @@ class Context:
self.save_dirty = False self.save_dirty = False
self.tags = ['AP'] self.tags = ['AP']
self.games: typing.Dict[int, str] = {} self.games: typing.Dict[int, str] = {}
self.played_games = set()
self.minimum_client_versions: typing.Dict[int, Version] = {} self.minimum_client_versions: typing.Dict[int, Version] = {}
self.seed_name = "" self.seed_name = ""
self.groups = {} self.groups = {}
@@ -315,9 +334,10 @@ class Context:
self.stored_data_notification_clients = collections.defaultdict(weakref.WeakSet) self.stored_data_notification_clients = collections.defaultdict(weakref.WeakSet)
self.read_data = {} self.read_data = {}
self.spheres = [] self.spheres = []
self.games_package_cache = games_package_cache or GamesPackageCache()
# init empty to satisfy linter, I suppose # init empty to satisfy linter, I suppose
self.gamespackage = {} self.reduced_games_package = {}
self.checksums = {} self.checksums = {}
self.item_name_groups = {} self.item_name_groups = {}
self.location_name_groups = {} self.location_name_groups = {}
@@ -329,50 +349,11 @@ class Context:
lambda: Utils.KeyedDefaultDict(lambda code: f'Unknown location (ID:{code})')) lambda: Utils.KeyedDefaultDict(lambda code: f'Unknown location (ID:{code})'))
self.non_hintable_names = collections.defaultdict(frozenset) self.non_hintable_names = collections.defaultdict(frozenset)
self._load_game_data()
# Data package retrieval
def _load_game_data(self):
import worlds
self.gamespackage = worlds.network_data_package["games"]
self.item_name_groups = {world_name: world.item_name_groups for world_name, world in
worlds.AutoWorldRegister.world_types.items()}
self.location_name_groups = {world_name: world.location_name_groups for world_name, world in
worlds.AutoWorldRegister.world_types.items()}
for world_name, world in worlds.AutoWorldRegister.world_types.items():
self.non_hintable_names[world_name] = world.hint_blacklist
for game_package in self.gamespackage.values():
# remove groups from data sent to clients
del game_package["item_name_groups"]
del game_package["location_name_groups"]
def _init_game_data(self):
for game_name, game_package in self.gamespackage.items():
if "checksum" in game_package:
self.checksums[game_name] = game_package["checksum"]
for item_name, item_id in game_package["item_name_to_id"].items():
self.item_names[game_name][item_id] = item_name
for location_name, location_id in game_package["location_name_to_id"].items():
self.location_names[game_name][location_id] = location_name
self.all_item_and_group_names[game_name] = \
set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name])
self.all_location_and_group_names[game_name] = \
set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, []))
archipelago_item_names = self.item_names["Archipelago"]
archipelago_location_names = self.location_names["Archipelago"]
for game in [game_name for game_name in self.gamespackage if game_name != "Archipelago"]:
# Add Archipelago items and locations to each data package.
self.item_names[game].update(archipelago_item_names)
self.location_names[game].update(archipelago_location_names)
def item_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: def item_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]:
return self.gamespackage[game]["item_name_to_id"] if game in self.gamespackage else None return self.reduced_games_package[game]["item_name_to_id"] if game in self.reduced_games_package else None
def location_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]: def location_names_for_game(self, game: str) -> typing.Optional[typing.Dict[str, int]]:
return self.gamespackage[game]["location_name_to_id"] if game in self.gamespackage else None return self.reduced_games_package[game]["location_name_to_id"] if game in self.reduced_games_package else None
# General networking # General networking
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool: async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
@@ -482,19 +463,17 @@ class Context:
with open(multidatapath, 'rb') as f: with open(multidatapath, 'rb') as f:
data = f.read() data = f.read()
self._load(self.decompress(data), {}, use_embedded_server_options) self._load(self.decompress(data), use_embedded_server_options)
self.data_filename = multidatapath self.data_filename = multidatapath
@staticmethod @staticmethod
def decompress(data: bytes) -> dict: def decompress(data: bytes) -> typing.Any:
format_version = data[0] format_version = data[0]
if format_version > 3: if format_version > 3:
raise Utils.VersionException("Incompatible multidata.") raise Utils.VersionException("Incompatible multidata.")
return restricted_loads(zlib.decompress(data[1:])) return restricted_loads(zlib.decompress(data[1:]))
def _load(self, decoded_obj: MultiData, game_data_packages: typing.Dict[str, typing.Any], def _load(self, decoded_obj: MultiData, use_embedded_server_options: bool) -> None:
use_embedded_server_options: bool):
self.read_data = {} self.read_data = {}
# there might be a better place to put this. # there might be a better place to put this.
race_mode = decoded_obj.get("race_mode", 0) race_mode = decoded_obj.get("race_mode", 0)
@@ -515,6 +494,7 @@ class Context:
self.slot_info = decoded_obj["slot_info"] self.slot_info = decoded_obj["slot_info"]
self.games = {slot: slot_info.game for slot, slot_info in self.slot_info.items()} self.games = {slot: slot_info.game for slot, slot_info in self.slot_info.items()}
self.played_games = {"Archipelago"} | {self.games[x] for x in range(1, len(self.games) + 1)}
self.groups = {slot: set(slot_info.group_members) for slot, slot_info in self.slot_info.items() self.groups = {slot: set(slot_info.group_members) for slot, slot_info in self.slot_info.items()
if slot_info.type == SlotType.group} if slot_info.type == SlotType.group}
@@ -559,18 +539,11 @@ class Context:
server_options = decoded_obj.get("server_options", {}) server_options = decoded_obj.get("server_options", {})
self._set_options(server_options) self._set_options(server_options)
# embedded data package # load and apply world data and (embedded) data package
for game_name, data in decoded_obj.get("datapackage", {}).items(): self._load_world_data()
if game_name in game_data_packages: self._load_data_package(decoded_obj.get("datapackage", {}))
data = game_data_packages[game_name]
self.logger.info(f"Loading embedded data package for game {game_name}")
self.gamespackage[game_name] = data
self.item_name_groups[game_name] = data["item_name_groups"]
if "location_name_groups" in data:
self.location_name_groups[game_name] = data["location_name_groups"]
del data["location_name_groups"]
del data["item_name_groups"] # remove from data package, but keep in self.item_name_groups
self._init_game_data() self._init_game_data()
for game_name, data in self.item_name_groups.items(): for game_name, data in self.item_name_groups.items():
self.read_data[f"item_name_groups_{game_name}"] = lambda lgame=game_name: self.item_name_groups[lgame] self.read_data[f"item_name_groups_{game_name}"] = lambda lgame=game_name: self.item_name_groups[lgame]
for game_name, data in self.location_name_groups.items(): for game_name, data in self.location_name_groups.items():
@@ -579,6 +552,55 @@ class Context:
# sorted access spheres # sorted access spheres
self.spheres = decoded_obj.get("spheres", []) self.spheres = decoded_obj.get("spheres", [])
def _load_world_data(self) -> None:
import worlds
for world_name, world in worlds.AutoWorldRegister.world_types.items():
# TODO: move hint_blacklist into GamesPackage?
self.non_hintable_names[world_name] = world.hint_blacklist
def _load_data_package(self, data_package: dict[str, GamesPackage]) -> None:
"""Populates reduced_games_package, item_name_groups, location_name_groups from static data and data_package"""
# NOTE: for worlds loaded from db, only checksum is set in GamesPackage, but this is handled by cache
for game_name in sorted(self.played_games):
if game_name in data_package:
self.logger.info(f"Loading embedded data package for game {game_name}")
data = self.games_package_cache.get(game_name, data_package[game_name])
else:
# NOTE: we still allow uploading a game without datapackage. Once that is changed, we could drop this.
data = self.games_package_cache.get_static(game_name)
(
self.reduced_games_package[game_name],
self.item_name_groups[game_name],
self.location_name_groups[game_name],
) = data
del self.games_package_cache # Not used past this point. Free memory.
def _init_game_data(self) -> None:
"""Update internal values from previously loaded data packages"""
for game_name, game_package in self.reduced_games_package.items():
if game_name not in self.played_games:
continue
if "checksum" in game_package:
self.checksums[game_name] = game_package["checksum"]
# NOTE: we could save more memory by moving the stuff below to data package cache as well
for item_name, item_id in game_package["item_name_to_id"].items():
self.item_names[game_name][item_id] = item_name
for location_name, location_id in game_package["location_name_to_id"].items():
self.location_names[game_name][location_id] = location_name
self.all_item_and_group_names[game_name] = \
set(game_package["item_name_to_id"]) | set(self.item_name_groups[game_name])
self.all_location_and_group_names[game_name] = \
set(game_package["location_name_to_id"]) | set(self.location_name_groups.get(game_name, []))
archipelago_item_names = self.item_names["Archipelago"]
archipelago_location_names = self.location_names["Archipelago"]
for game in [game_name for game_name in self.reduced_games_package if game_name != "Archipelago"]:
# Add Archipelago items and locations to each data package.
self.item_names[game].update(archipelago_item_names)
self.location_names[game].update(archipelago_location_names)
# saving # saving
def save(self, now=False) -> bool: def save(self, now=False) -> bool:
@@ -919,12 +941,10 @@ async def server(websocket: "ServerConnection", path: str = "/", ctx: Context =
async def on_client_connected(ctx: Context, client: Client): async def on_client_connected(ctx: Context, client: Client):
games = {ctx.games[x] for x in range(1, len(ctx.games) + 1)}
games.add("Archipelago")
await ctx.send_msgs(client, [{ await ctx.send_msgs(client, [{
'cmd': 'RoomInfo', 'cmd': 'RoomInfo',
'password': bool(ctx.password), 'password': bool(ctx.password),
'games': games, 'games': sorted(ctx.played_games),
# tags are for additional features in the communication. # tags are for additional features in the communication.
# Name them by feature or fork, as you feel is appropriate. # Name them by feature or fork, as you feel is appropriate.
'tags': ctx.tags, 'tags': ctx.tags,
@@ -933,8 +953,7 @@ async def on_client_connected(ctx: Context, client: Client):
'permissions': get_permissions(ctx), 'permissions': get_permissions(ctx),
'hint_cost': ctx.hint_cost, 'hint_cost': ctx.hint_cost,
'location_check_points': ctx.location_check_points, 'location_check_points': ctx.location_check_points,
'datapackage_checksums': {game: game_data["checksum"] for game, game_data 'datapackage_checksums': ctx.checksums,
in ctx.gamespackage.items() if game in games and "checksum" in game_data},
'seed_name': ctx.seed_name, 'seed_name': ctx.seed_name,
'time': time.time(), 'time': time.time(),
}]) }])
@@ -1940,25 +1959,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
await ctx.send_msgs(client, reply) await ctx.send_msgs(client, reply)
elif cmd == "GetDataPackage": elif cmd == "GetDataPackage":
exclusions = args.get("exclusions", []) games = {
if "games" in args: name: game_data for name, game_data in ctx.reduced_games_package.items()
games = {name: game_data for name, game_data in ctx.gamespackage.items() if name in set(args.get("games", []))
if name in set(args.get("games", []))} }
await ctx.send_msgs(client, [{"cmd": "DataPackage", await ctx.send_msgs(client, [{"cmd": "DataPackage", "data": {"games": games}}])
"data": {"games": games}}])
# TODO: remove exclusions behaviour around 0.5.0
elif exclusions:
exclusions = set(exclusions)
games = {name: game_data for name, game_data in ctx.gamespackage.items()
if name not in exclusions}
package = {"games": games}
await ctx.send_msgs(client, [{"cmd": "DataPackage",
"data": package}])
else:
await ctx.send_msgs(client, [{"cmd": "DataPackage",
"data": {"games": ctx.gamespackage}}])
elif client.auth: elif client.auth:
if cmd == "ConnectUpdate": if cmd == "ConnectUpdate":
+65 -76
View File
@@ -13,6 +13,7 @@ import threading
import time import time
import typing import typing
import sys import sys
from asyncio import AbstractEventLoop
import websockets import websockets
from pony.orm import commit, db_session, select from pony.orm import commit, db_session, select
@@ -24,8 +25,10 @@ from MultiServer import (
server_per_message_deflate_factory, server_per_message_deflate_factory,
) )
from Utils import restricted_loads, cache_argsless from Utils import restricted_loads, cache_argsless
from NetUtils import GamesPackage
from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache
from .locker import Locker from .locker import Locker
from .models import Command, GameDataPackage, Room, db from .models import Command, Room, db
class CustomClientMessageProcessor(ClientMessageProcessor): class CustomClientMessageProcessor(ClientMessageProcessor):
@@ -62,18 +65,39 @@ class DBCommandProcessor(ServerCommandProcessor):
class WebHostContext(Context): class WebHostContext(Context):
room_id: int room_id: int
video: dict[tuple[int, int], tuple[str, str]]
main_loop: AbstractEventLoop
static_server_data: StaticServerData
def __init__(self, static_server_data: dict, logger: logging.Logger): def __init__(
self,
static_server_data: StaticServerData,
games_package_cache: DBGamesPackageCache,
logger: logging.Logger,
) -> None:
# static server data is used during _load_game_data to load required data, # static server data is used during _load_game_data to load required data,
# without needing to import worlds system, which takes quite a bit of memory # without needing to import worlds system, which takes quite a bit of memory
self.static_server_data = static_server_data super(WebHostContext, self).__init__(
super(WebHostContext, self).__init__("", 0, "", "", 1, "",
40, True, "enabled", "enabled", 0,
"enabled", 0, 2, logger=logger) "",
del self.static_server_data "",
self.main_loop = asyncio.get_running_loop() 1,
self.video = {} 40,
True,
"enabled",
"enabled",
"enabled",
0,
2,
games_package_cache=games_package_cache,
logger=logger,
)
self.tags = ["AP", "WebHost"] self.tags = ["AP", "WebHost"]
self.video = {}
self.main_loop = asyncio.get_running_loop()
self.static_server_data = static_server_data
self.games_package_cache = games_package_cache
def __del__(self): def __del__(self):
try: try:
@@ -83,12 +107,6 @@ class WebHostContext(Context):
except ImportError: except ImportError:
self.logger.debug("Context destroyed") self.logger.debug("Context destroyed")
def _load_game_data(self):
for key, value in self.static_server_data.items():
# NOTE: attributes are mutable and shared, so they will have to be copied before being modified
setattr(self, key, value)
self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names)
async def listen_to_db_commands(self): async def listen_to_db_commands(self):
cmdprocessor = DBCommandProcessor(self) cmdprocessor = DBCommandProcessor(self)
@@ -118,42 +136,14 @@ class WebHostContext(Context):
self.port = get_random_port() self.port = get_random_port()
multidata = self.decompress(room.seed.multidata) multidata = self.decompress(room.seed.multidata)
game_data_packages = {} return self._load(multidata, True)
static_gamespackage = self.gamespackage # this is shared across all rooms def _load_world_data(self):
static_item_name_groups = self.item_name_groups # Use static_server_data, but skip static data package since that is in cache anyway.
static_location_name_groups = self.location_name_groups # Also NOT importing worlds here!
self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})} # this may be modified by _load # FIXME: does this copy the non_hintable_names (also for games not part of the room)?
self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})} self.non_hintable_names = collections.defaultdict(frozenset, self.static_server_data["non_hintable_names"])
self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})} del self.static_server_data # Not used past this point. Free memory.
missing_checksum = False
for game in list(multidata.get("datapackage", {})):
game_data = multidata["datapackage"][game]
if "checksum" in game_data:
if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]:
# non-custom. remove from multidata and use static data
# games package could be dropped from static data once all rooms embed data package
del multidata["datapackage"][game]
else:
row = GameDataPackage.get(checksum=game_data["checksum"])
if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete
game_data_packages[game] = restricted_loads(row.data)
continue
else:
self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}")
else:
missing_checksum = True # Game rolled on old AP and will load data package from multidata
self.gamespackage[game] = static_gamespackage.get(game, {})
self.item_name_groups[game] = static_item_name_groups.get(game, {})
self.location_name_groups[game] = static_location_name_groups.get(game, {})
if not game_data_packages and not missing_checksum:
# all static -> use the static dicts directly
self.gamespackage = static_gamespackage
self.item_name_groups = static_item_name_groups
self.location_name_groups = static_location_name_groups
return self._load(multidata, game_data_packages, True)
def init_save(self, enabled: bool = True): def init_save(self, enabled: bool = True):
self.saving = enabled self.saving = enabled
@@ -185,34 +175,23 @@ def get_random_port():
return random.randint(49152, 65535) return random.randint(49152, 65535)
class StaticServerData(typing.TypedDict, total=True):
non_hintable_names: dict[str, typing.AbstractSet[str]]
games_package: dict[str, GamesPackage]
@cache_argsless @cache_argsless
def get_static_server_data() -> dict: def get_static_server_data() -> StaticServerData:
import worlds import worlds
data = {
return {
"non_hintable_names": { "non_hintable_names": {
world_name: world.hint_blacklist world_name: world.hint_blacklist
for world_name, world in worlds.AutoWorldRegister.world_types.items() for world_name, world in worlds.AutoWorldRegister.world_types.items()
}, },
"gamespackage": { "games_package": worlds.network_data_package["games"]
world_name: {
key: value
for key, value in game_package.items()
if key not in ("item_name_groups", "location_name_groups")
}
for world_name, game_package in worlds.network_data_package["games"].items()
},
"item_name_groups": {
world_name: world.item_name_groups
for world_name, world in worlds.AutoWorldRegister.world_types.items()
},
"location_name_groups": {
world_name: world.location_name_groups
for world_name, world in worlds.AutoWorldRegister.world_types.items()
},
} }
return data
def set_up_logging(room_id) -> logging.Logger: def set_up_logging(room_id) -> logging.Logger:
import os import os
@@ -245,9 +224,18 @@ def tear_down_logging(room_id):
del logging.Logger.manager.loggerDict[logger_name] del logging.Logger.manager.loggerDict[logger_name]
def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, def run_server_process(
cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], name: str,
host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue): ponyconfig: dict[str, typing.Any],
static_server_data: StaticServerData,
cert_file: typing.Optional[str],
cert_key_file: typing.Optional[str],
host: str,
rooms_to_run: multiprocessing.Queue,
rooms_shutting_down: multiprocessing.Queue,
) -> None:
import gc
from setproctitle import setproctitle from setproctitle import setproctitle
setproctitle(name) setproctitle(name)
@@ -263,6 +251,9 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit)) resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
del resource, file_limit del resource, file_limit
# prime the data package cache with static data
games_package_cache = DBGamesPackageCache(static_server_data["games_package"])
# establish DB connection for multidata and multisave # establish DB connection for multidata and multisave
db.bind(**ponyconfig) db.bind(**ponyconfig)
db.generate_mapping(check_tables=False) db.generate_mapping(check_tables=False)
@@ -270,8 +261,6 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
if "worlds" in sys.modules: if "worlds" in sys.modules:
raise Exception("Worlds system should not be loaded in the custom server.") raise Exception("Worlds system should not be loaded in the custom server.")
import gc
if not cert_file: if not cert_file:
def get_ssl_context(): def get_ssl_context():
return None return None
@@ -296,7 +285,7 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
with Locker(f"RoomLocker {room_id}"): with Locker(f"RoomLocker {room_id}"):
try: try:
logger = set_up_logging(room_id) logger = set_up_logging(room_id)
ctx = WebHostContext(static_server_data, logger) ctx = WebHostContext(static_server_data, games_package_cache, logger)
ctx.load(room_id) ctx.load(room_id)
ctx.init_save() ctx.init_save()
assert ctx.server is None assert ctx.server is None
View File
View File
+96
View File
@@ -0,0 +1,96 @@
import typing as t
from weakref import WeakValueDictionary
from NetUtils import GamesPackage
GameAndChecksum = tuple[str, str | None]
ItemNameGroups = dict[str, list[str]]
LocationNameGroups = dict[str, list[str]]
K = t.TypeVar("K")
V = t.TypeVar("V")
class DictLike(dict[K, V]):
__slots__ = ("__weakref__",)
class GamesPackageCache:
# NOTE: this uses 3 separate collections because unpacking the get() result would end the container lifetime
_reduced_games_packages: WeakValueDictionary[GameAndChecksum, GamesPackage]
"""Does not include item_name_groups nor location_name_groups"""
_item_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]]
_location_name_groups: WeakValueDictionary[GameAndChecksum, dict[str, list[str]]]
def __init__(self) -> None:
self._reduced_games_packages = WeakValueDictionary()
self._item_name_groups = WeakValueDictionary()
self._location_name_groups = WeakValueDictionary()
def _get(
self,
cache_key: GameAndChecksum,
) -> tuple[GamesPackage | None, ItemNameGroups | None, LocationNameGroups | None]:
if cache_key[1] is None:
return None, None, None
return (
self._reduced_games_packages.get(cache_key, None),
self._item_name_groups.get(cache_key, None),
self._location_name_groups.get(cache_key, None),
)
def get(
self,
game: str,
full_games_package: GamesPackage,
) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
"""Loads and caches embedded data package provided by multidata"""
cache_key = (game, full_games_package.get("checksum", None))
cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups = self._get(cache_key)
if cached_reduced_games_package is None:
cached_reduced_games_package = t.cast(
t.Any,
DictLike(
{
"item_name_to_id": full_games_package["item_name_to_id"],
"location_name_to_id": full_games_package["location_name_to_id"],
"checksum": full_games_package.get("checksum", None),
}
),
)
if cache_key[1] is not None: # only cache if checksum is available
self._reduced_games_packages[cache_key] = cached_reduced_games_package
if cached_item_name_groups is None:
# optimize strings to be references instead of copies
item_names = {name: name for name in cached_reduced_games_package["item_name_to_id"].keys()}
cached_item_name_groups = DictLike(
{
group_name: [item_names.get(item_name, item_name) for item_name in group_items]
for group_name, group_items in full_games_package["item_name_groups"].items()
}
)
if cache_key[1] is not None: # only cache if checksum is available
self._item_name_groups[cache_key] = cached_item_name_groups
if cached_location_name_groups is None:
# optimize strings to be references instead of copies
location_names = {name: name for name in cached_reduced_games_package["location_name_to_id"].keys()}
cached_location_name_groups = DictLike(
{
group_name: [location_names.get(location_name, location_name) for location_name in group_locations]
for group_name, group_locations in full_games_package.get("location_name_groups", {}).items()
}
)
if cache_key[1] is not None: # only cache if checksum is available
self._location_name_groups[cache_key] = cached_location_name_groups
return cached_reduced_games_package, cached_item_name_groups, cached_location_name_groups
def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
"""Loads legacy data package from installed worlds"""
import worlds
return self.get(game, worlds.network_data_package["games"][game])
View File
@@ -0,0 +1,42 @@
from typing_extensions import override
from NetUtils import GamesPackage
from Utils import restricted_loads
from apmw.multiserver.gamespackagecache import GamesPackageCache, ItemNameGroups, LocationNameGroups
class DBGamesPackageCache(GamesPackageCache):
_static: dict[str, tuple[GamesPackage, ItemNameGroups, LocationNameGroups]]
def __init__(self, static_games_package: dict[str, GamesPackage]) -> None:
super().__init__()
self._static = {
game: GamesPackageCache.get(self, game, games_package)
for game, games_package in static_games_package.items()
}
@override
def get(
self,
game: str,
full_games_package: GamesPackage,
) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
# for games started on webhost, full_games_package is likely unpopulated and only has the checksum field
cache_key = (game, full_games_package.get("checksum", None))
cached = self._get(cache_key)
if any(value is None for value in cached):
if "checksum" not in full_games_package:
return super().get(game, full_games_package) # no checksum, assume fully populated
from WebHostLib.models import GameDataPackage
row: GameDataPackage | None = GameDataPackage.get(checksum=full_games_package["checksum"])
if row: # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8 ...
return super().get(game, restricted_loads(row.data))
return super().get(game, full_games_package) # ... in which case full_games_package should be populated
return cached # type: ignore # mypy doesn't understand any value is None
@override
def get_static(self, game: str) -> tuple[GamesPackage, ItemNameGroups, LocationNameGroups]:
return self._static[game]
View File
+132
View File
@@ -0,0 +1,132 @@
import typing as t
from copy import deepcopy
from unittest import TestCase
from typing_extensions import override
import NetUtils
from NetUtils import GamesPackage
from apmw.multiserver.gamespackagecache import GamesPackageCache
class GamesPackageCacheTest(TestCase):
cache: GamesPackageCache
any_game: t.ClassVar[str] = "APQuest"
example_games_package: GamesPackage = {
"item_name_to_id": {"Item 1": 1},
"item_name_groups": {"Everything": ["Item 1"]},
"location_name_to_id": {"Location 1": 1},
"location_name_groups": {"Everywhere": ["Location 1"]},
"checksum": "1234",
}
@override
def setUp(self) -> None:
self.cache = GamesPackageCache()
def test_get_static_is_same(self) -> None:
"""Tests that get_static returns the same objects twice"""
reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get_static(self.any_game)
reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game)
self.assertIs(reduced_games_package1, reduced_games_package2)
self.assertIs(item_name_groups1, item_name_groups2)
self.assertIs(location_name_groups1, location_name_groups2)
def test_get_static_data_format(self) -> None:
"""Tests that get_static returns data in the correct format"""
reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game)
self.assertTrue(reduced_games_package["checksum"])
self.assertTrue(reduced_games_package["item_name_to_id"])
self.assertTrue(reduced_games_package["location_name_to_id"])
self.assertNotIn("item_name_groups", reduced_games_package)
self.assertNotIn("location_name_groups", reduced_games_package)
self.assertTrue(item_name_groups["Everything"])
self.assertTrue(location_name_groups["Everywhere"])
def test_get_static_is_serializable(self) -> None:
"""Tests that get_static returns data that can be serialized"""
NetUtils.encode(self.cache.get_static(self.any_game))
def test_get_static_missing_raises(self) -> None:
"""Tests that get_static raises KeyError if the world is missing"""
with self.assertRaises(KeyError):
_ = self.cache.get_static("Does not exist")
def test_eviction(self) -> None:
"""Tests that unused items get evicted from cache"""
game_name = "Test"
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, self.example_games_package)
self.assertTrue(data)
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
del data
if len(self.cache._reduced_games_packages) != before_add: # gc.collect() may not even be required
import gc
gc.collect()
self.assertEqual(before_add, len(self.cache._reduced_games_packages))
def test_get_required_field(self) -> None:
"""Tests that missing required field raises a KeyError"""
for field in ("item_name_to_id", "location_name_to_id", "item_name_groups"):
with self.subTest(field=field):
games_package = deepcopy(self.example_games_package)
del games_package[field] # type: ignore
with self.assertRaises(KeyError):
_ = self.cache.get(self.any_game, games_package)
def test_get_optional_properties(self) -> None:
"""Tests that missing optional field works"""
for field in ("checksum", "location_name_groups"):
with self.subTest(field=field):
games_package = deepcopy(self.example_games_package)
del games_package[field] # type: ignore
_, item_name_groups, location_name_groups = self.cache.get(self.any_game, games_package)
self.assertTrue(item_name_groups)
self.assertEqual(field != "location_name_groups", bool(location_name_groups))
def test_item_name_deduplication(self) -> None:
n = 1
s1 = f"Item {n}"
s2 = f"Item {n}"
# check if the deduplication is actually gonna do anything
self.assertIsNot(s1, s2)
self.assertEqual(s1, s2)
# do the thing
game_name = "Test"
games_package: GamesPackage = {
"item_name_to_id": {s1: n},
"item_name_groups": {"Everything": [s2]},
"location_name_to_id": {},
"location_name_groups": {},
"checksum": "1234",
}
reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package)
self.assertIs(
next(iter(reduced_games_package["item_name_to_id"].keys())),
item_name_groups["Everything"][0],
)
def test_location_name_deduplication(self) -> None:
n = 1
s1 = f"Location {n}"
s2 = f"Location {n}"
# check if the deduplication is actually gonna do anything
self.assertIsNot(s1, s2)
self.assertEqual(s1, s2)
# do the thing
game_name = "Test"
games_package: GamesPackage = {
"item_name_to_id": {},
"item_name_groups": {},
"location_name_to_id": {s1: n},
"location_name_groups": {"Everywhere": [s2]},
"checksum": "1234",
}
reduced_games_package, item_name_groups, location_name_groups = self.cache.get(game_name, games_package)
self.assertIs(
next(iter(reduced_games_package["location_name_to_id"].keys())),
location_name_groups["Everywhere"][0],
)
@@ -0,0 +1,147 @@
import typing as t
from copy import deepcopy
from typing_extensions import override
from test.multiserver.test_gamespackage_cache import GamesPackageCacheTest
import Utils
from NetUtils import GamesPackage
from apmw.webhost.customserver.gamespackagecache import DBGamesPackageCache
class FakeGameDataPackage:
_rows: "t.ClassVar[dict[str, FakeGameDataPackage]]" = {}
data: bytes
@classmethod
def get(cls, checksum: str) -> "FakeGameDataPackage | None":
return cls._rows.get(checksum, None)
@classmethod
def add(cls, checksum: str, full_games_package: GamesPackage) -> None:
row = FakeGameDataPackage()
row.data = Utils.restricted_dumps(full_games_package)
cls._rows[checksum] = row
class DBGamesPackageCacheTest(GamesPackageCacheTest):
cache: DBGamesPackageCache
any_game: t.ClassVar[str] = "My Game"
static_data: t.ClassVar[dict[str, GamesPackage]] = { # noqa: pycharm doesn't understand this
"My Game": {
"item_name_to_id": {"Item 1": 1},
"location_name_to_id": {"Location 1": 1},
"item_name_groups": {"Everything": ["Item 1"]},
"location_name_groups": {"Everywhere": ["Location 1"]},
"checksum": "2345",
}
}
orig_db_type: t.ClassVar[type]
@override
@classmethod
def setUpClass(cls) -> None:
import WebHostLib.models
cls.orig_db_type = WebHostLib.models.GameDataPackage
WebHostLib.models.GameDataPackage = FakeGameDataPackage # type: ignore
@override
def setUp(self) -> None:
self.cache = DBGamesPackageCache(self.static_data)
@override
@classmethod
def tearDownClass(cls) -> None:
import WebHostLib.models
WebHostLib.models.GameDataPackage = cls.orig_db_type # type: ignore
def assert_conversion(
self,
full_games_package: GamesPackage,
reduced_games_package: dict[str, t.Any],
item_name_groups: dict[str, t.Any],
location_name_groups: dict[str, t.Any],
) -> None:
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
if key in full_games_package:
self.assertEqual(reduced_games_package[key], full_games_package[key]) # noqa: pycharm
self.assertEqual(item_name_groups, full_games_package["item_name_groups"])
self.assertEqual(location_name_groups, full_games_package["location_name_groups"])
def assert_static_conversion(
self,
full_games_package: GamesPackage,
reduced_games_package: dict[str, t.Any],
item_name_groups: dict[str, t.Any],
location_name_groups: dict[str, t.Any],
) -> None:
self.assert_conversion(full_games_package, reduced_games_package, item_name_groups, location_name_groups)
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
self.assertIs(reduced_games_package[key], full_games_package[key]) # noqa: pycharm
def test_get_static_contents(self) -> None:
"""Tests that get_static returns the correct data"""
reduced_games_package, item_name_groups, location_name_groups = self.cache.get_static(self.any_game)
for key in ("item_name_to_id", "location_name_to_id", "checksum"):
self.assertIs(reduced_games_package[key], self.static_data[self.any_game][key]) # noqa: pycharm
self.assertEqual(item_name_groups, self.static_data[self.any_game]["item_name_groups"])
self.assertEqual(location_name_groups, self.static_data[self.any_game]["location_name_groups"])
def test_static_not_evicted(self) -> None:
"""Tests that static data is not evicted from cache during gc"""
import gc
game_name = next(iter(self.static_data.keys()))
ids = [id(o) for o in self.cache.get_static(game_name)]
gc.collect()
self.assertEqual(ids, [id(o) for o in self.cache.get_static(game_name)])
def test_get_is_static(self) -> None:
"""Tests that a get with correct checksum return the static items"""
# NOTE: this is only true for the DB cache, not the "regular" one, since we want to avoid loading worlds there
cks: GamesPackage = {"checksum": self.static_data[self.any_game]["checksum"]} # noqa: pycharm doesn't like this
reduced_games_package1, item_name_groups1, location_name_groups1 = self.cache.get(self.any_game, cks)
reduced_games_package2, item_name_groups2, location_name_groups2 = self.cache.get_static(self.any_game)
self.assertIs(reduced_games_package1, reduced_games_package2)
self.assertEqual(location_name_groups1, location_name_groups2)
self.assertEqual(item_name_groups1, item_name_groups2)
def test_get_from_db(self) -> None:
"""Tests that a get with only checksum will load the full data from db and is cached"""
game_name = "Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
full_games_package["checksum"] = "3456"
cks: GamesPackage = {"checksum": full_games_package["checksum"]} # noqa: pycharm doesn't like this
FakeGameDataPackage.add(full_games_package["checksum"], full_games_package)
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, cks)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
def test_get_missing_from_db_uses_full_games_package(self) -> None:
"""Tests that a get with full data (missing from db) will use the full data and is cached"""
game_name = "Yet Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
full_games_package["checksum"] = "4567"
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, full_games_package)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add + 1, len(self.cache._reduced_games_packages))
def test_get_without_checksum_uses_full_games_package(self) -> None:
"""Tests that a get with full data and no checksum will use the full data and is not cached"""
game_name = "Yet Another Game"
full_games_package = deepcopy(self.static_data[self.any_game])
del full_games_package["checksum"]
before_add = len(self.cache._reduced_games_packages)
data = self.cache.get(game_name, full_games_package)
self.assert_conversion(full_games_package, *data) # type: ignore
self.assertEqual(before_add, len(self.cache._reduced_games_packages))
def test_get_missing_from_db_raises(self) -> None:
"""Tests that a get that requires a row to exist raise an exception if it doesn't"""
with self.assertRaises(Exception):
_ = self.cache.get("Does not exist", {"checksum": "0000"})