use weights for random port and remove more-itertools

This commit is contained in:
Uriel
2026-03-05 17:01:08 -03:00
parent 551dbf44f6
commit f03d1cad3e
2 changed files with 49 additions and 24 deletions

View File

@@ -8,13 +8,13 @@ import itertools
import logging
import multiprocessing
import pickle
import random
import socket
import threading
import time
import typing
import sys
from more_itertools import value_chain, random_permutation
import psutil
import websockets
from pony.orm import commit, db_session, select
@@ -184,24 +184,56 @@ class WebHostContext(Context):
return d
def get_random_port(game_ports: list[str | int], host: str):
available_ports = []
@functools.cache
def parse_game_ports(game_ports: tuple[str | int]):
available_ports: list[range | list[int]] = []
weights = []
ephemeral_allowed = False
total_length = 0
for item in game_ports:
if isinstance(item, str) and "-" in item:
start, end = map(int, item.split("-"))
available_ports.append(range(start, end+1))
x = range(start, end + 1)
total_length += len(x)
weights.append(total_length)
available_ports.append(x)
elif int(item) == 0:
ephemeral_allowed = True
else:
available_ports.append(int(item))
total_length += 1
weights.append(total_length)
available_ports.append([int(item)])
return get_port_from_list(
# limit amount of checked ports to 1024
random_permutation(
filter(lambda p: p not in get_used_ports(), value_chain(*available_ports)),
1024),
ephemeral_allowed, host)
return available_ports, weights, total_length, ephemeral_allowed
def get_random_port(game_ports: list[str | int], host: str) -> socket.socket:
# convert to tuple because its hashable
available_ports, weights, length, ephemeral_allowed = parse_game_ports(tuple(game_ports))
ports = random.choices(available_ports, cum_weights=weights, k=len(available_ports))
remaining = 1024
for r in ports:
r_length = len(r)
if isinstance(r, range):
random_range = itertools.islice(
filter(
lambda p: p not in get_used_ports(),
map(lambda _: random.randint(r.start, r.stop), range(r_length))
),
remaining)
port = get_port_from_list(random_range, host)
else:
port = get_port_from_list(filter(lambda p: p not in get_used_ports(), r), host)
remaining -= r_length
if port is not None: return port
if remaining <= 0: break
if ephemeral_allowed:
return socket.create_server((host, 0))
raise OSError(98, "No available ports")
_last_used_ports = (frozenset(map(lambda c: c.laddr.port, psutil.net_connections("tcp4"))), round(time.time() / 900))
@@ -214,20 +246,14 @@ def get_used_ports():
return _last_used_ports[0]
def get_port_from_list(available_ports: typing.Iterable[int], ephemeral_allowed: bool, host) -> socket.socket:
def get_port_from_list(available_ports: typing.Iterable[int], host: str) -> socket.socket | None:
for port in available_ports:
sock = get_socket_if_free(host, port)
if sock is not None: return sock
else:
if ephemeral_allowed: return socket.create_server((host, 0))
raise OSError(98, "No available ports")
try:
return socket.create_server((host, port))
except OSError:
_ = None
def get_socket_if_free(host, port: int) -> socket.socket | None:
try:
return socket.create_server((host, port))
except OSError:
return None
return None
@cache_argsless

View File

@@ -12,4 +12,3 @@ markupsafe>=3.0.2
setproctitle>=1.3.5
mistune>=3.1.3
docutils>=0.22.2
more-itertools>=10.8.0