diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 29f9f0402d..c2d9690a12 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -210,7 +210,8 @@ def parse_game_ports(game_ports: tuple[str | int]): def create_random_port_socket(game_ports: tuple[str | int], host: str) -> socket.socket: parsed_ports, weights, length, ephemeral_allowed = parse_game_ports(game_ports) - port_ranges = random.choices(parsed_ports, cum_weights=weights, k=len(parsed_ports)) + # try to randomize the order of parsed ports with weights, but don't have duplicates of them + port_ranges = list(dict.fromkeys(random.choices(parsed_ports, weights=weights, k=len(parsed_ports)) + parsed_ports)) remaining = 1024 for r in port_ranges: r_length = len(r) @@ -342,8 +343,8 @@ def tear_down_logging(room_id): def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], - host: str, game_ports: list, rooms_to_run: multiprocessing.Queue, - rooms_shutting_down: multiprocessing.Queue): + host: str, game_ports: typing.Iterable[str | int], + rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue): from setproctitle import setproctitle setproctitle(name) @@ -359,6 +360,10 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit)) del resource, file_limit + # convert to tuple because its hashable + if not isinstance(game_ports, tuple): + game_ports = tuple(game_ports) + # establish DB connection for multidata and multisave db.bind(**ponyconfig) db.generate_mapping(check_tables=False) @@ -411,8 +416,7 @@ def run_server_process(name: str, ponyconfig: dict, static_server_data: dict, if ctx.port == 0: ctx.server = websockets.serve( functools.partial(server, ctx=ctx), - # convert to tuple because its hashable - sock=create_random_port_socket(tuple(game_ports), ctx.host), + sock=create_random_port_socket(game_ports, ctx.host), ssl=get_ssl_context(), extensions=[server_per_message_deflate_factory], )