diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 7248bf3bac..e353cf2ab2 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -89,19 +89,24 @@ class WebHostContext(Context): setattr(self, key, value) self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names) - def listen_to_db_commands(self): + async def listen_to_db_commands(self): cmdprocessor = DBCommandProcessor(self) while not self.exit_event.is_set(): - with db_session: - commands = select(command for command in Command if command.room.id == self.room_id) - if commands: - for command in commands: - self.main_loop.call_soon_threadsafe(cmdprocessor, command.commandtext) - command.delete() - commit() - del commands - time.sleep(5) + await self.main_loop.run_in_executor(None, self._process_db_commands, cmdprocessor) + try: + await asyncio.wait_for(self.exit_event.wait(), 5) + except asyncio.TimeoutError: + pass + + def _process_db_commands(self, cmdprocessor): + with db_session: + commands = select(command for command in Command if command.room.id == self.room_id) + if commands: + for command in commands: + self.main_loop.call_soon_threadsafe(cmdprocessor, command.commandtext) + command.delete() + commit() @db_session def load(self, room_id: int): @@ -156,9 +161,9 @@ class WebHostContext(Context): with db_session: savegame_data = Room.get(id=self.room_id).multisave if savegame_data: - self.set_save(restricted_loads(Room.get(id=self.room_id).multisave)) + self.set_save(restricted_loads(savegame_data)) self._start_async_saving(atexit_save=False) - threading.Thread(target=self.listen_to_db_commands, daemon=True).start() + asyncio.create_task(self.listen_to_db_commands()) @db_session def _save(self, exit_save: bool = False) -> bool: @@ -229,6 +234,17 @@ def set_up_logging(room_id) -> logging.Logger: return logger +def tear_down_logging(room_id): + """Close logging handling for a room.""" + logger_name = f"RoomLogger {room_id}" + if logger_name in logging.Logger.manager.loggerDict: + logger = logging.getLogger(logger_name) + for handler in logger.handlers[:]: + logger.removeHandler(handler) + handler.close() + del logging.Logger.manager.loggerDict[logger_name] + + def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue): @@ -343,12 +359,18 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, ctx.save_dirty = False # make sure the saving thread does not write to DB after final wakeup ctx.exit_event.set() # make sure the saving thread stops at some point # NOTE: async saving should probably be an async task and could be merged with shutdown_task + + if ctx.server and hasattr(ctx.server, "ws_server"): + ctx.server.ws_server.close() + await ctx.server.ws_server.wait_closed() + with db_session: # ensure the Room does not spin up again on its own, minute of safety buffer room = Room.get(id=room_id) room.last_activity = datetime.datetime.utcnow() - \ datetime.timedelta(minutes=1, seconds=room.timeout) del room + tear_down_logging(room_id) logging.info(f"Shutting down room {room_id} on {name}.") finally: await asyncio.sleep(5) diff --git a/test/webhost/test_host_room.py b/test/webhost/test_host_room.py index 4aa83e3b1c..0f43fea208 100644 --- a/test/webhost/test_host_room.py +++ b/test/webhost/test_host_room.py @@ -1,11 +1,22 @@ +import logging import os from uuid import UUID, uuid4, uuid5 from flask import url_for +from WebHostLib.customserver import set_up_logging, tear_down_logging from . import TestBase +def _cleanup_logger(room_id: UUID) -> None: + from Utils import user_path + tear_down_logging(room_id) + try: + os.unlink(user_path("logs", f"{room_id}.txt")) + except OSError: + pass + + class TestHostFakeRoom(TestBase): room_id: UUID log_filename: str @@ -39,7 +50,7 @@ class TestHostFakeRoom(TestBase): try: os.unlink(self.log_filename) - except FileNotFoundError: + except OSError: pass def test_display_log_missing_full(self) -> None: @@ -191,3 +202,27 @@ class TestHostFakeRoom(TestBase): with db_session: commands = select(command for command in Command if command.room.id == self.room_id) # type: ignore self.assertNotIn("/help", (command.commandtext for command in commands)) + + def test_logger_teardown(self) -> None: + """Verify that room loggers are removed from the global logging manager.""" + from WebHostLib.customserver import tear_down_logging + room_id = uuid4() + self.addCleanup(_cleanup_logger, room_id) + set_up_logging(room_id) + self.assertIn(f"RoomLogger {room_id}", logging.Logger.manager.loggerDict) + tear_down_logging(room_id) + self.assertNotIn(f"RoomLogger {room_id}", logging.Logger.manager.loggerDict) + + def test_handler_teardown(self) -> None: + """Verify that handlers for room loggers are closed by tear_down_logging.""" + from WebHostLib.customserver import tear_down_logging + room_id = uuid4() + self.addCleanup(_cleanup_logger, room_id) + logger = set_up_logging(room_id) + handlers = logger.handlers[:] + self.assertGreater(len(handlers), 0) + + tear_down_logging(room_id) + for handler in handlers: + if isinstance(handler, logging.FileHandler): + self.assertTrue(handler.stream is None or handler.stream.closed)