Merge remote-tracking branch 'imurx/custom-port-range' into active/rc-site

This commit is contained in:
black-sliver
2026-03-10 22:05:58 +01:00
5 changed files with 208 additions and 13 deletions

View File

@@ -4,6 +4,7 @@ import asyncio
import collections
import datetime
import functools
import itertools
import logging
import multiprocessing
import pickle
@@ -15,6 +16,7 @@ import typing
import sys
from asyncio import AbstractEventLoop
import psutil
import websockets
from pony.orm import commit, db_session, select
@@ -133,7 +135,7 @@ class WebHostContext(Context):
if room.last_port:
self.port = room.last_port
else:
self.port = get_random_port()
self.port = 0
multidata = self.decompress(room.seed.multidata)
return self._load(multidata, True)
@@ -171,8 +173,98 @@ class WebHostContext(Context):
return d
def get_random_port():
return random.randint(49152, 65535)
class GameRangePorts(typing.NamedTuple):
parsed_ports: list[range]
weights: list[int]
ephemeral_allowed: bool
@functools.cache
def parse_game_ports(game_ports: tuple[str | int, ...]) -> GameRangePorts:
parsed_ports: list[range] = []
weights: list[int] = []
ephemeral_allowed = False
total_length = 0
for item in game_ports:
if isinstance(item, str) and "-" in item:
start, end = map(int, item.split("-"))
x = range(start, end + 1)
total_length += len(x)
weights.append(total_length)
parsed_ports.append(x)
elif int(item) == 0:
ephemeral_allowed = True
else:
total_length += 1
weights.append(total_length)
num = int(item)
parsed_ports.append(range(num, num + 1))
return GameRangePorts(parsed_ports, weights, ephemeral_allowed)
def weighted_random(ranges: list[range], cum_weights: list[int]) -> int:
[picked] = random.choices(ranges, cum_weights=cum_weights)
return random.randrange(picked.start, picked.stop, picked.step)
def create_random_port_socket(game_ports: tuple[str | int, ...], host: str) -> socket.socket:
parsed_ports, weights, ephemeral_allowed = parse_game_ports(game_ports)
used_ports = get_used_ports()
i = 1024 if len(parsed_ports) > 0 else 0
while i > 0:
port_num = weighted_random(parsed_ports, weights)
if port_num in used_ports:
used_ports = get_used_ports()
continue
i -= 0
try:
return socket.create_server((host, port_num))
except OSError:
pass
if ephemeral_allowed:
return socket.create_server((host, 0))
raise OSError(98, "No available ports")
def try_conns_per_process(p: psutil.Process) -> typing.Iterable[int]:
try:
return (c.laddr.port for c in p.net_connections("tcp4"))
except psutil.AccessDenied:
return ()
def get_active_net_connections() -> typing.Iterable[int]:
# Don't even try to check if system using AIX
if psutil.AIX:
return ()
try:
return (c.laddr.port for c in psutil.net_connections("tcp4"))
# raises AccessDenied when done on macOS
except psutil.AccessDenied:
# flatten the list of iterables
return itertools.chain.from_iterable(map(
# get the net connections of the process and then map its ports
try_conns_per_process,
# this method has caching handled by psutil
psutil.process_iter(["net_connections"])
))
def get_used_ports():
last_used_ports: tuple[frozenset[int], float] | None = getattr(get_used_ports, "last", None)
t_hash = round(time.time() / 90) # cache for 90 seconds
if last_used_ports is None or last_used_ports[1] != t_hash:
last_used_ports = (frozenset(get_active_net_connections()), t_hash)
setattr(get_used_ports, "last", last_used_ports)
return last_used_ports[0]
class StaticServerData(typing.TypedDict, total=True):
@@ -231,6 +323,7 @@ def run_server_process(
cert_file: typing.Optional[str],
cert_key_file: typing.Optional[str],
host: str,
game_ports: typing.Iterable[str | int],
rooms_to_run: multiprocessing.Queue,
rooms_shutting_down: multiprocessing.Queue,
) -> None:
@@ -253,6 +346,8 @@ def run_server_process(
# prime the data package cache with static data
games_package_cache = DBGamesPackageCache(static_server_data["games_package"])
# convert to tuple because its hashable
game_ports = tuple(game_ports)
# establish DB connection for multidata and multisave
db.bind(**ponyconfig)
@@ -289,20 +384,26 @@ def run_server_process(
ctx.load(room_id)
ctx.init_save()
assert ctx.server is None
try:
if ctx.port != 0:
try:
ctx.server = websockets.serve(
functools.partial(server, ctx=ctx),
ctx.host,
ctx.port,
ssl=get_ssl_context(),
extensions=[server_per_message_deflate_factory],
)
await ctx.server
except OSError:
ctx.port = 0
if ctx.port == 0:
ctx.server = websockets.serve(
functools.partial(server, ctx=ctx),
ctx.host,
ctx.port,
sock=create_random_port_socket(game_ports, ctx.host),
ssl=get_ssl_context(),
extensions=[server_per_message_deflate_factory],
)
await ctx.server
except OSError: # likely port in use
ctx.server = websockets.serve(
functools.partial(server, ctx=ctx), ctx.host, 0, ssl=get_ssl_context())
await ctx.server
port = 0
for wssocket in ctx.server.ws_server.sockets:
socketname = wssocket.getsockname()
@@ -377,7 +478,7 @@ def run_server_process(
def run(self):
while 1:
next_room = rooms_to_run.get(block=True, timeout=None)
next_room = rooms_to_run.get(block=True, timeout=None)
gc.collect()
task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop)
self._tasks.append(task)