diff --git a/WebHostLib/__init__.py b/WebHostLib/__init__.py index d10c17bff8..fb80b3d9c3 100644 --- a/WebHostLib/__init__.py +++ b/WebHostLib/__init__.py @@ -42,6 +42,7 @@ app.config["SELFLAUNCH"] = True # application process is in charge of launching app.config["SELFLAUNCHCERT"] = None # can point to a SSL Certificate to encrypt Room websocket connections app.config["SELFLAUNCHKEY"] = None # can point to a SSL Certificate Key to encrypt Room websocket connections app.config["SELFGEN"] = True # application process is in charge of scheduling Generations. +app.config["GAME_PORTS"] = ["49152-65535", 0] # at what amount of worlds should scheduling be used, instead of rolling in the web-thread app.config["JOB_THRESHOLD"] = 1 # after what time in seconds should generation be aborted, freeing the queue slot. Can be set to None to disable. diff --git a/WebHostLib/autolauncher.py b/WebHostLib/autolauncher.py index b48c6a8cbb..bd8c18b4c7 100644 --- a/WebHostLib/autolauncher.py +++ b/WebHostLib/autolauncher.py @@ -187,6 +187,7 @@ class MultiworldInstance(): self.cert = config["SELFLAUNCHCERT"] self.key = config["SELFLAUNCHKEY"] self.host = config["HOST_ADDRESS"] + self.game_ports = config["GAME_PORTS"] self.rooms_to_start = multiprocessing.Queue() self.rooms_shutting_down = multiprocessing.Queue() self.name = f"MultiHoster{id}" @@ -197,7 +198,7 @@ class MultiworldInstance(): process = multiprocessing.Process(group=None, target=run_server_process, args=(self.name, self.ponyconfig, get_static_server_data(), - self.cert, self.key, self.host, + self.cert, self.key, self.host, self.game_ports, self.rooms_to_start, self.rooms_shutting_down), name=self.name) process.start() diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 2cade4960d..c332312486 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -4,6 +4,7 @@ import asyncio import collections import datetime import functools +import itertools import logging import multiprocessing import pickle @@ -15,6 +16,7 @@ import typing import sys from asyncio import AbstractEventLoop +import psutil import websockets from pony.orm import commit, db_session, select @@ -133,7 +135,7 @@ class WebHostContext(Context): if room.last_port: self.port = room.last_port else: - self.port = get_random_port() + self.port = 0 multidata = self.decompress(room.seed.multidata) return self._load(multidata, True) @@ -171,8 +173,98 @@ class WebHostContext(Context): return d -def get_random_port(): - return random.randint(49152, 65535) +class GameRangePorts(typing.NamedTuple): + parsed_ports: list[range] + weights: list[int] + ephemeral_allowed: bool + + +@functools.cache +def parse_game_ports(game_ports: tuple[str | int, ...]) -> GameRangePorts: + parsed_ports: list[range] = [] + weights: list[int] = [] + ephemeral_allowed = False + total_length = 0 + + for item in game_ports: + if isinstance(item, str) and "-" in item: + start, end = map(int, item.split("-")) + x = range(start, end + 1) + total_length += len(x) + weights.append(total_length) + parsed_ports.append(x) + elif int(item) == 0: + ephemeral_allowed = True + else: + total_length += 1 + weights.append(total_length) + num = int(item) + parsed_ports.append(range(num, num + 1)) + + return GameRangePorts(parsed_ports, weights, ephemeral_allowed) + + +def weighted_random(ranges: list[range], cum_weights: list[int]) -> int: + [picked] = random.choices(ranges, cum_weights=cum_weights) + return random.randrange(picked.start, picked.stop, picked.step) + + +def create_random_port_socket(game_ports: tuple[str | int, ...], host: str) -> socket.socket: + parsed_ports, weights, ephemeral_allowed = parse_game_ports(game_ports) + used_ports = get_used_ports() + i = 1024 if len(parsed_ports) > 0 else 0 + while i > 0: + port_num = weighted_random(parsed_ports, weights) + if port_num in used_ports: + used_ports = get_used_ports() + continue + + i -= 0 + + try: + return socket.create_server((host, port_num)) + except OSError: + pass + + if ephemeral_allowed: + return socket.create_server((host, 0)) + + raise OSError(98, "No available ports") + + +def try_conns_per_process(p: psutil.Process) -> typing.Iterable[int]: + try: + return (c.laddr.port for c in p.net_connections("tcp4")) + except psutil.AccessDenied: + return () + + +def get_active_net_connections() -> typing.Iterable[int]: + # Don't even try to check if system using AIX + if psutil.AIX: + return () + + try: + return (c.laddr.port for c in psutil.net_connections("tcp4")) + # raises AccessDenied when done on macOS + except psutil.AccessDenied: + # flatten the list of iterables + return itertools.chain.from_iterable(map( + # get the net connections of the process and then map its ports + try_conns_per_process, + # this method has caching handled by psutil + psutil.process_iter(["net_connections"]) + )) + + +def get_used_ports(): + last_used_ports: tuple[frozenset[int], float] | None = getattr(get_used_ports, "last", None) + t_hash = round(time.time() / 90) # cache for 90 seconds + if last_used_ports is None or last_used_ports[1] != t_hash: + last_used_ports = (frozenset(get_active_net_connections()), t_hash) + setattr(get_used_ports, "last", last_used_ports) + + return last_used_ports[0] class StaticServerData(typing.TypedDict, total=True): @@ -231,6 +323,7 @@ def run_server_process( cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], host: str, + game_ports: typing.Iterable[str | int], rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue, ) -> None: @@ -253,6 +346,8 @@ def run_server_process( # prime the data package cache with static data games_package_cache = DBGamesPackageCache(static_server_data["games_package"]) + # convert to tuple because its hashable + game_ports = tuple(game_ports) # establish DB connection for multidata and multisave db.bind(**ponyconfig) @@ -289,20 +384,26 @@ def run_server_process( ctx.load(room_id) ctx.init_save() assert ctx.server is None - try: + if ctx.port != 0: + try: + ctx.server = websockets.serve( + functools.partial(server, ctx=ctx), + ctx.host, + ctx.port, + ssl=get_ssl_context(), + extensions=[server_per_message_deflate_factory], + ) + await ctx.server + except OSError: + ctx.port = 0 + if ctx.port == 0: ctx.server = websockets.serve( functools.partial(server, ctx=ctx), - ctx.host, - ctx.port, + sock=create_random_port_socket(game_ports, ctx.host), ssl=get_ssl_context(), extensions=[server_per_message_deflate_factory], ) await ctx.server - except OSError: # likely port in use - ctx.server = websockets.serve( - functools.partial(server, ctx=ctx), ctx.host, 0, ssl=get_ssl_context()) - - await ctx.server port = 0 for wssocket in ctx.server.ws_server.sockets: socketname = wssocket.getsockname() @@ -377,7 +478,7 @@ def run_server_process( def run(self): while 1: - next_room = rooms_to_run.get(block=True, timeout=None) + next_room = rooms_to_run.get(block=True, timeout=None) gc.collect() task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop) self._tasks.append(task) diff --git a/docs/webhost configuration sample.yaml b/docs/webhost configuration sample.yaml index 93094f1ce7..059faeeef9 100644 --- a/docs/webhost configuration sample.yaml +++ b/docs/webhost configuration sample.yaml @@ -17,6 +17,12 @@ # Web hosting port #PORT: 80 +# Ports used for game hosting. Values can be specific ports, port ranges or both. Default is: [49152-65535, 0] +# Zero means it will use a random free port if there is no port in the next 1024 randomly chosen ports from the range +# Examples of valid values: [40000-41000, 49152-65535] +# If ports within the range(s) are already in use, the WebHost will fallback to the default [49152-65535, 0] range. +#GAME_PORTS: [49152-65535, 0] + # Place where uploads go. #UPLOAD_FOLDER: uploads diff --git a/test/webhost/test_port_allocation.py b/test/webhost/test_port_allocation.py new file mode 100644 index 0000000000..d20e82295e --- /dev/null +++ b/test/webhost/test_port_allocation.py @@ -0,0 +1,86 @@ +import os +import unittest + +from Utils import is_macos +from WebHostLib.customserver import parse_game_ports, create_random_port_socket, get_used_ports + +ci = bool(os.environ.get("CI")) + + +class TestPortAllocating(unittest.TestCase): + def test_parse_game_ports(self) -> None: + """Ensure that game ports with ranges are parsed correctly""" + val = parse_game_ports(("1000-2000", "2000-5000", "1000-2000", 20, 40, "20", "0")) + + self.assertListEqual(val.parsed_ports, + [range(1000, 2001), range(2000, 5001), range(1000, 2001), range(20, 21), range(40, 41), + range(20, 21)], "The parsed game ports are not the expected values") + self.assertTrue(val.ephemeral_allowed, "The ephemeral allowed flag is not set even though it was passed") + self.assertListEqual(val.weights, [1001, 4002, 5003, 5004, 5005, 5006], + "Cumulative weights are not the expected value") + + val = parse_game_ports(()) + self.assertListEqual(val.parsed_ports, [], "Empty list of game port returned something") + self.assertFalse(val.ephemeral_allowed, "Empty list returned that ephemeral is allowed") + + val = parse_game_ports((0,)) + self.assertListEqual(val.parsed_ports, [], "Empty list of ranges returned something") + self.assertTrue(val.ephemeral_allowed, "List with just 0 is not allowing ephemeral ports") + + val = parse_game_ports((1,)) + self.assertEqual(val.parsed_ports, [range(1, 2)], "Parsed ports doesn't contain the expected values") + self.assertFalse(val.ephemeral_allowed, "List with just single port returned that ephemeral is allowed") + + def test_parse_game_port_errors(self) -> None: + """Ensure that game ports with incorrect values raise the expected error""" + with self.assertRaises(ValueError, msg="Negative numbers didn't get interpreted as an invalid range"): + parse_game_ports(tuple("-50215")) + with self.assertRaises(ValueError, msg="Text got interpreted as a valid number"): + parse_game_ports(tuple("dwafawg")) + with self.assertRaises( + ValueError, + msg="A range with an extra dash at the end didn't get interpreted as an invalid number because of it's end dash" + ): + parse_game_ports(tuple("20-21215-")) + with self.assertRaises(ValueError, msg="Text got interpreted as a valid number for the start of a range"): + parse_game_ports(tuple("f-21215")) + + def test_random_port_socket_edge_cases(self) -> None: + """Verify if edge cases on creation of random port socket is working fine""" + # Try giving an empty tuple and fail over it + with self.assertRaises(OSError) as err: + create_random_port_socket(tuple(), "127.0.0.1") + self.assertEqual(err.exception.errno, 98, "Raised an unexpected error code") + self.assertEqual(err.exception.strerror, "No available ports", "Raised an unexpected error string") + + # Try only having ephemeral ports enabled + try: + create_random_port_socket(("0",), "127.0.0.1").close() + except OSError as err: + self.assertEqual(err.errno, 98, "Raised an unexpected error code") + # If it returns our error string that means something is wrong with our code + self.assertNotEqual(err.strerror, "No available ports", + "Raised an unexpected error string") + + @unittest.skipUnless(ci, "can't guarantee free ports outside of CI") + def test_random_port_socket(self) -> None: + """Verify if returned sockets use the correct port ranges""" + sockets = [] + for _ in range(6): + socket = create_random_port_socket(("8080-8085",), "127.0.0.1") + sockets.append(socket) + _, port = socket.getsockname() + self.assertIn(port, range(8080, 8086), "Port of socket was not inside the expected range") + for s in sockets: + s.close() + + sockets.clear() + length = 5_000 if is_macos else (30_000 - len(get_used_ports())) + for _ in range(length): + socket = create_random_port_socket(("30000-65535",), "127.0.0.1") + sockets.append(socket) + _, port = socket.getsockname() + self.assertIn(port, range(30_000, 65536), "Port of socket was not inside the expected range") + + for s in sockets: + s.close()