mirror of
https://github.com/ArchipelagoMW/Archipelago.git
synced 2026-03-23 11:43:24 -07:00
Merge remote-tracking branch 'imurx/custom-port-range' into active/rc-site
This commit is contained in:
@@ -4,6 +4,7 @@ import asyncio
|
||||
import collections
|
||||
import datetime
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import pickle
|
||||
@@ -15,6 +16,7 @@ import typing
|
||||
import sys
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
import psutil
|
||||
import websockets
|
||||
from pony.orm import commit, db_session, select
|
||||
|
||||
@@ -133,7 +135,7 @@ class WebHostContext(Context):
|
||||
if room.last_port:
|
||||
self.port = room.last_port
|
||||
else:
|
||||
self.port = get_random_port()
|
||||
self.port = 0
|
||||
|
||||
multidata = self.decompress(room.seed.multidata)
|
||||
return self._load(multidata, True)
|
||||
@@ -171,8 +173,98 @@ class WebHostContext(Context):
|
||||
return d
|
||||
|
||||
|
||||
def get_random_port():
|
||||
return random.randint(49152, 65535)
|
||||
class GameRangePorts(typing.NamedTuple):
|
||||
parsed_ports: list[range]
|
||||
weights: list[int]
|
||||
ephemeral_allowed: bool
|
||||
|
||||
|
||||
@functools.cache
|
||||
def parse_game_ports(game_ports: tuple[str | int, ...]) -> GameRangePorts:
|
||||
parsed_ports: list[range] = []
|
||||
weights: list[int] = []
|
||||
ephemeral_allowed = False
|
||||
total_length = 0
|
||||
|
||||
for item in game_ports:
|
||||
if isinstance(item, str) and "-" in item:
|
||||
start, end = map(int, item.split("-"))
|
||||
x = range(start, end + 1)
|
||||
total_length += len(x)
|
||||
weights.append(total_length)
|
||||
parsed_ports.append(x)
|
||||
elif int(item) == 0:
|
||||
ephemeral_allowed = True
|
||||
else:
|
||||
total_length += 1
|
||||
weights.append(total_length)
|
||||
num = int(item)
|
||||
parsed_ports.append(range(num, num + 1))
|
||||
|
||||
return GameRangePorts(parsed_ports, weights, ephemeral_allowed)
|
||||
|
||||
|
||||
def weighted_random(ranges: list[range], cum_weights: list[int]) -> int:
|
||||
[picked] = random.choices(ranges, cum_weights=cum_weights)
|
||||
return random.randrange(picked.start, picked.stop, picked.step)
|
||||
|
||||
|
||||
def create_random_port_socket(game_ports: tuple[str | int, ...], host: str) -> socket.socket:
|
||||
parsed_ports, weights, ephemeral_allowed = parse_game_ports(game_ports)
|
||||
used_ports = get_used_ports()
|
||||
i = 1024 if len(parsed_ports) > 0 else 0
|
||||
while i > 0:
|
||||
port_num = weighted_random(parsed_ports, weights)
|
||||
if port_num in used_ports:
|
||||
used_ports = get_used_ports()
|
||||
continue
|
||||
|
||||
i -= 0
|
||||
|
||||
try:
|
||||
return socket.create_server((host, port_num))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if ephemeral_allowed:
|
||||
return socket.create_server((host, 0))
|
||||
|
||||
raise OSError(98, "No available ports")
|
||||
|
||||
|
||||
def try_conns_per_process(p: psutil.Process) -> typing.Iterable[int]:
|
||||
try:
|
||||
return (c.laddr.port for c in p.net_connections("tcp4"))
|
||||
except psutil.AccessDenied:
|
||||
return ()
|
||||
|
||||
|
||||
def get_active_net_connections() -> typing.Iterable[int]:
|
||||
# Don't even try to check if system using AIX
|
||||
if psutil.AIX:
|
||||
return ()
|
||||
|
||||
try:
|
||||
return (c.laddr.port for c in psutil.net_connections("tcp4"))
|
||||
# raises AccessDenied when done on macOS
|
||||
except psutil.AccessDenied:
|
||||
# flatten the list of iterables
|
||||
return itertools.chain.from_iterable(map(
|
||||
# get the net connections of the process and then map its ports
|
||||
try_conns_per_process,
|
||||
# this method has caching handled by psutil
|
||||
psutil.process_iter(["net_connections"])
|
||||
))
|
||||
|
||||
|
||||
def get_used_ports():
|
||||
last_used_ports: tuple[frozenset[int], float] | None = getattr(get_used_ports, "last", None)
|
||||
t_hash = round(time.time() / 90) # cache for 90 seconds
|
||||
if last_used_ports is None or last_used_ports[1] != t_hash:
|
||||
last_used_ports = (frozenset(get_active_net_connections()), t_hash)
|
||||
setattr(get_used_ports, "last", last_used_ports)
|
||||
|
||||
return last_used_ports[0]
|
||||
|
||||
|
||||
class StaticServerData(typing.TypedDict, total=True):
|
||||
@@ -231,6 +323,7 @@ def run_server_process(
|
||||
cert_file: typing.Optional[str],
|
||||
cert_key_file: typing.Optional[str],
|
||||
host: str,
|
||||
game_ports: typing.Iterable[str | int],
|
||||
rooms_to_run: multiprocessing.Queue,
|
||||
rooms_shutting_down: multiprocessing.Queue,
|
||||
) -> None:
|
||||
@@ -253,6 +346,8 @@ def run_server_process(
|
||||
|
||||
# prime the data package cache with static data
|
||||
games_package_cache = DBGamesPackageCache(static_server_data["games_package"])
|
||||
# convert to tuple because its hashable
|
||||
game_ports = tuple(game_ports)
|
||||
|
||||
# establish DB connection for multidata and multisave
|
||||
db.bind(**ponyconfig)
|
||||
@@ -289,20 +384,26 @@ def run_server_process(
|
||||
ctx.load(room_id)
|
||||
ctx.init_save()
|
||||
assert ctx.server is None
|
||||
try:
|
||||
if ctx.port != 0:
|
||||
try:
|
||||
ctx.server = websockets.serve(
|
||||
functools.partial(server, ctx=ctx),
|
||||
ctx.host,
|
||||
ctx.port,
|
||||
ssl=get_ssl_context(),
|
||||
extensions=[server_per_message_deflate_factory],
|
||||
)
|
||||
await ctx.server
|
||||
except OSError:
|
||||
ctx.port = 0
|
||||
if ctx.port == 0:
|
||||
ctx.server = websockets.serve(
|
||||
functools.partial(server, ctx=ctx),
|
||||
ctx.host,
|
||||
ctx.port,
|
||||
sock=create_random_port_socket(game_ports, ctx.host),
|
||||
ssl=get_ssl_context(),
|
||||
extensions=[server_per_message_deflate_factory],
|
||||
)
|
||||
await ctx.server
|
||||
except OSError: # likely port in use
|
||||
ctx.server = websockets.serve(
|
||||
functools.partial(server, ctx=ctx), ctx.host, 0, ssl=get_ssl_context())
|
||||
|
||||
await ctx.server
|
||||
port = 0
|
||||
for wssocket in ctx.server.ws_server.sockets:
|
||||
socketname = wssocket.getsockname()
|
||||
@@ -377,7 +478,7 @@ def run_server_process(
|
||||
|
||||
def run(self):
|
||||
while 1:
|
||||
next_room = rooms_to_run.get(block=True, timeout=None)
|
||||
next_room = rooms_to_run.get(block=True, timeout=None)
|
||||
gc.collect()
|
||||
task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop)
|
||||
self._tasks.append(task)
|
||||
|
||||
Reference in New Issue
Block a user