diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index f968d412ed..dddad61d05 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -191,7 +191,7 @@ class GameRangePorts(typing.NamedTuple): @functools.cache -def parse_game_ports(game_ports: tuple[str | int]) -> GameRangePorts: +def parse_game_ports(game_ports: tuple[str | int, ...]) -> GameRangePorts: parsed_ports: list[range] = [] weights = [] ephemeral_allowed = False @@ -220,10 +220,10 @@ def weighted_random(ranges: list[range], cum_weights: list[int]) -> int: return random.randrange(picked.start, picked.stop, picked.step) -def create_random_port_socket(game_ports: tuple[str | int], host: str) -> socket.socket: +def create_random_port_socket(game_ports: tuple[str | int, ...], host: str) -> socket.socket: parsed_ports, weights, ephemeral_allowed = parse_game_ports(game_ports) used_ports = get_used_ports() - i = 1024 + i = 1024 if len(parsed_ports) > 0 else 0 while i > 0: port_num = weighted_random(parsed_ports, weights) if port_num in used_ports: diff --git a/test/webhost/test_port_allocation.py b/test/webhost/test_port_allocation.py new file mode 100644 index 0000000000..de22abbe7a --- /dev/null +++ b/test/webhost/test_port_allocation.py @@ -0,0 +1,100 @@ +import os +import statistics +import timeit +import unittest + +from WebHostLib.customserver import parse_game_ports, create_random_port_socket, get_used_ports + +ci = bool(os.environ.get("CI")) + + +class TestWebDescriptions(unittest.TestCase): + def test_parse_game_ports(self) -> None: + """Ensure that game ports with ranges are parsed correctly""" + val = parse_game_ports(("1000-2000", "2000-5000", "1000-2000", 20, 40, "20", "0")) + self.assertEqual(len(val.parsed_ports), 6, "Parsed port ranges is not the expected length") + self.assertEqual(len(val.weights), 6, "Parsed weights are not the expected length") + + self.assertEqual(val.parsed_ports[0], range(1000, 2001), "The first range wasn't parsed correctly") + self.assertEqual(val.parsed_ports[1], range(2000, 5001), "The second range wasn't parsed correctly") + self.assertEqual(val.parsed_ports[0], val.parsed_ports[2], + "The first and third range are not the same when they should be") + self.assertEqual(val.parsed_ports[3], range(20, 21), "The fourth range wasn't parsed correctly") + self.assertEqual(val.parsed_ports[4], range(40, 41), "The fifth range was not parsed correctly") + self.assertEqual(val.parsed_ports[3], val.parsed_ports[5], + "The fourth and last range are not the same when they should be") + + self.assertTrue(val.ephemeral_allowed, "The ephemeral allowed flag is not set even though it was passed") + + self.assertListEqual(val.weights, [1001, 4002, 5003, 5004, 5005, 5006], + "Cumulative weights are not the expected value") + + def test_parse_game_port_errors(self) -> None: + """Ensure that game ports with incorrect values raise the expected error""" + with self.assertRaises(ValueError, msg="Negative numbers didn't get interpreted as an invalid range"): + parse_game_ports(tuple("-50215")) + with self.assertRaises(ValueError, msg="Text got interpreted as a valid number"): + parse_game_ports(tuple("dwafawg")) + with self.assertRaises( + ValueError, + msg="A range with an extra dash at the end didn't get interpreted as an invalid number because of it's end dash" + ): + parse_game_ports(tuple("20-21215-")) + with self.assertRaises(ValueError, msg="Text got interpreted as a valid number for the start of a range"): + parse_game_ports(tuple("f-21215")) + + def test_random_port_socket_edge_cases(self) -> None: + # Try giving an empty tuple and fail over it + with self.assertRaises(OSError) as err: + create_random_port_socket(tuple(), "127.0.0.1") + self.assertEqual(err.exception.errno, 98, "Raised an unexpected error code") + self.assertEqual(err.exception.strerror, "No available ports", "Raised an unexpected error string") + + # Try only having ephemeral ports enabled + try: + create_random_port_socket(("0",), "127.0.0.1").close() + except OSError as err: + self.assertEqual(err.errno, 98, "Raised an unexpected error code") + # If it returns our error string that means something is wrong with our code + self.assertNotEqual(err.strerror, "No available ports", + "Raised an unexpected error string") + + # @unittest.skipUnless(ci, "can't guarantee free ports outside of CI") + def test_random_port_socket(self) -> None: + sockets = [] + for _ in range(6): + socket = create_random_port_socket(("8080-8085",), "127.0.0.1") + sockets.append(socket) + _, port = socket.getsockname() + self.assertIn(port, range(8080,8086), "Port of socket was not inside the expected range") + for s in sockets: + s.close() + + # Compared averages were calculated with a range of 100 in a Linux machine and then rounded up + sockets.clear() + time = [] + size = 65535 - (len(get_used_ports()) + 1024 + 4000) + for _ in range(10): + time.append(timeit.timeit(lambda: sockets.append( + create_random_port_socket(("1024-30000", "30001-65535"), "127.0.0.1") + ), number=size)) + + for s in sockets: + s.close() + + self.assertLess(statistics.fmean(time), 1.2, + f"Time took to allocate {size} ports consecutively is higher than expected") + + sockets.clear() + time.clear() + size = 65535 - (len(get_used_ports()) + 1024 + 5) + for _ in range(10): + time.append(timeit.timeit(lambda: sockets.append( + create_random_port_socket(("1024-30000", "30001-65535"), "127.0.0.1") + ), number=size)) + + for s in sockets: + s.close() + + self.assertLess(statistics.fmean(time), 5, + f"Time took to allocate {size} ports consecutively is higher than expected")