This commit is contained in:
CookieCat
2023-11-04 17:55:48 -04:00
parent 60df274157
commit bd8698e1fd
102 changed files with 1190 additions and 1463 deletions

View File

@@ -1,15 +1,14 @@
from __future__ import annotations
import copy
import itertools
import functools
import logging
import random
import secrets
import typing # this can go away when Python 3.8 support is dropped
from argparse import Namespace
from collections import Counter, deque
from collections.abc import Collection, MutableSequence
from collections import ChainMap, Counter, deque
from collections.abc import Collection
from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union, \
Type, ClassVar
@@ -48,6 +47,7 @@ class ThreadBarrierProxy:
class MultiWorld():
debug_types = False
player_name: Dict[int, str]
_region_cache: Dict[int, Dict[str, Region]]
difficulty_requirements: dict
required_medallions: dict
dark_room_logic: Dict[int, str]
@@ -57,7 +57,7 @@ class MultiWorld():
plando_connections: List
worlds: Dict[int, auto_world]
groups: Dict[int, Group]
regions: RegionManager
regions: List[Region]
itempool: List[Item]
is_race: bool = False
precollected_items: Dict[int, List[Item]]
@@ -92,34 +92,6 @@ class MultiWorld():
def __getitem__(self, player) -> bool:
return self.rule(player)
class RegionManager:
region_cache: Dict[int, Dict[str, Region]]
entrance_cache: Dict[int, Dict[str, Entrance]]
location_cache: Dict[int, Dict[str, Location]]
def __init__(self, players: int):
self.region_cache = {player: {} for player in range(1, players+1)}
self.entrance_cache = {player: {} for player in range(1, players+1)}
self.location_cache = {player: {} for player in range(1, players+1)}
def __iadd__(self, other: Iterable[Region]):
self.extend(other)
return self
def append(self, region: Region):
self.region_cache[region.player][region.name] = region
def extend(self, regions: Iterable[Region]):
for region in regions:
self.region_cache[region.player][region.name] = region
def __iter__(self) -> Iterator[Region]:
for regions in self.region_cache.values():
yield from regions.values()
def __len__(self):
return sum(len(regions) for regions in self.region_cache.values())
def __init__(self, players: int):
# world-local random state is saved for multiple generations running concurrently
self.random = ThreadBarrierProxy(random.Random())
@@ -128,12 +100,16 @@ class MultiWorld():
self.glitch_triforce = False
self.algorithm = 'balanced'
self.groups = {}
self.regions = self.RegionManager(players)
self.regions = []
self.shops = []
self.itempool = []
self.seed = None
self.seed_name: str = "Unavailable"
self.precollected_items = {player: [] for player in self.player_ids}
self._cached_entrances = None
self._cached_locations = None
self._entrance_cache = {}
self._location_cache: Dict[Tuple[str, int], Location] = {}
self.required_locations = []
self.light_world_light_cone = False
self.dark_world_light_cone = False
@@ -161,6 +137,7 @@ class MultiWorld():
def set_player_attr(attr, val):
self.__dict__.setdefault(attr, {})[player] = val
set_player_attr('_region_cache', {})
set_player_attr('shuffle', "vanilla")
set_player_attr('logic', "noglitches")
set_player_attr('mode', 'open')
@@ -222,6 +199,7 @@ class MultiWorld():
self.game[new_id] = game
self.player_types[new_id] = NetUtils.SlotType.group
self._region_cache[new_id] = {}
world_type = AutoWorld.AutoWorldRegister.world_types[game]
self.worlds[new_id] = world_type.create_group(self, new_id, players)
self.worlds[new_id].collect_item = classmethod(AutoWorld.World.collect_item).__get__(self.worlds[new_id])
@@ -325,15 +303,11 @@ class MultiWorld():
def player_ids(self) -> Tuple[int, ...]:
return tuple(range(1, self.players + 1))
@Utils.cache_self1
@functools.lru_cache()
def get_game_players(self, game_name: str) -> Tuple[int, ...]:
return tuple(player for player in self.player_ids if self.game[player] == game_name)
@Utils.cache_self1
def get_game_groups(self, game_name: str) -> Tuple[int, ...]:
return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name)
@Utils.cache_self1
@functools.lru_cache()
def get_game_worlds(self, game_name: str):
return tuple(world for player, world in self.worlds.items() if
player not in self.groups and self.game[player] == game_name)
@@ -355,17 +329,41 @@ class MultiWorld():
def world_name_lookup(self):
return {self.player_name[player_id]: player_id for player_id in self.player_ids}
def _recache(self):
"""Rebuild world cache"""
self._cached_locations = None
for region in self.regions:
player = region.player
self._region_cache[player][region.name] = region
for exit in region.exits:
self._entrance_cache[exit.name, player] = exit
for r_location in region.locations:
self._location_cache[r_location.name, player] = r_location
def get_regions(self, player: Optional[int] = None) -> Collection[Region]:
return self.regions if player is None else self.regions.region_cache[player].values()
return self.regions if player is None else self._region_cache[player].values()
def get_region(self, region_name: str, player: int) -> Region:
return self.regions.region_cache[player][region_name]
def get_region(self, regionname: str, player: int) -> Region:
try:
return self._region_cache[player][regionname]
except KeyError:
self._recache()
return self._region_cache[player][regionname]
def get_entrance(self, entrance_name: str, player: int) -> Entrance:
return self.regions.entrance_cache[player][entrance_name]
def get_entrance(self, entrance: str, player: int) -> Entrance:
try:
return self._entrance_cache[entrance, player]
except KeyError:
self._recache()
return self._entrance_cache[entrance, player]
def get_location(self, location_name: str, player: int) -> Location:
return self.regions.location_cache[player][location_name]
def get_location(self, location: str, player: int) -> Location:
try:
return self._location_cache[location, player]
except KeyError:
self._recache()
return self._location_cache[location, player]
def get_all_state(self, use_cache: bool) -> CollectionState:
cached = getattr(self, "_all_state", None)
@@ -426,22 +424,28 @@ class MultiWorld():
logging.debug('Placed %s at %s', item, location)
def get_entrances(self, player: Optional[int] = None) -> Iterable[Entrance]:
if player is not None:
return self.regions.entrance_cache[player].values()
return Utils.RepeatableChain(tuple(self.regions.entrance_cache[player].values()
for player in self.regions.entrance_cache))
def get_entrances(self) -> List[Entrance]:
if self._cached_entrances is None:
self._cached_entrances = [entrance for region in self.regions for entrance in region.entrances]
return self._cached_entrances
def clear_entrance_cache(self):
self._cached_entrances = None
def register_indirect_condition(self, region: Region, entrance: Entrance):
"""Report that access to this Region can result in unlocking this Entrance,
state.can_reach(Region) in the Entrance's traversal condition, as opposed to pure transition logic."""
self.indirect_connections.setdefault(region, set()).add(entrance)
def get_locations(self, player: Optional[int] = None) -> Iterable[Location]:
def get_locations(self, player: Optional[int] = None) -> List[Location]:
if self._cached_locations is None:
self._cached_locations = [location for region in self.regions for location in region.locations]
if player is not None:
return self.regions.location_cache[player].values()
return Utils.RepeatableChain(tuple(self.regions.location_cache[player].values()
for player in self.regions.location_cache))
return [location for location in self._cached_locations if location.player == player]
return self._cached_locations
def clear_location_cache(self):
self._cached_locations = None
def get_unfilled_locations(self, player: Optional[int] = None) -> List[Location]:
return [location for location in self.get_locations(player) if location.item is None]
@@ -463,17 +467,16 @@ class MultiWorld():
valid_locations = [location.name for location in self.get_unfilled_locations(player)]
else:
valid_locations = location_names
relevant_cache = self.regions.location_cache[player]
for location_name in valid_locations:
location = relevant_cache.get(location_name, None)
if location and location.item is None:
location = self._location_cache.get((location_name, player), None)
if location is not None and location.item is None:
yield location
def unlocks_new_location(self, item: Item) -> bool:
temp_state = self.state.copy()
temp_state.collect(item, True)
for location in self.get_unfilled_locations(item.player):
for location in self.get_unfilled_locations():
if temp_state.can_reach(location) and not self.state.can_reach(location):
return True
@@ -605,7 +608,7 @@ PathValue = Tuple[str, Optional["PathValue"]]
class CollectionState():
prog_items: Dict[int, Counter[str]]
prog_items: typing.Counter[Tuple[str, int]]
multiworld: MultiWorld
reachable_regions: Dict[int, Set[Region]]
blocked_connections: Dict[int, Set[Entrance]]
@@ -617,7 +620,7 @@ class CollectionState():
additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
def __init__(self, parent: MultiWorld):
self.prog_items = {player: Counter() for player in parent.player_ids}
self.prog_items = Counter()
self.multiworld = parent
self.reachable_regions = {player: set() for player in parent.get_all_ids()}
self.blocked_connections = {player: set() for player in parent.get_all_ids()}
@@ -665,7 +668,7 @@ class CollectionState():
def copy(self) -> CollectionState:
ret = CollectionState(self.multiworld)
ret.prog_items = copy.deepcopy(self.prog_items)
ret.prog_items = self.prog_items.copy()
ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in
self.reachable_regions}
ret.blocked_connections = {player: copy.copy(self.blocked_connections[player]) for player in
@@ -709,23 +712,23 @@ class CollectionState():
self.collect(event.item, True, event)
def has(self, item: str, player: int, count: int = 1) -> bool:
return self.prog_items[player][item] >= count
return self.prog_items[item, player] >= count
def has_all(self, items: Set[str], player: int) -> bool:
"""Returns True if each item name of items is in state at least once."""
return all(self.prog_items[player][item] for item in items)
return all(self.prog_items[item, player] for item in items)
def has_any(self, items: Set[str], player: int) -> bool:
"""Returns True if at least one item name of items is in state at least once."""
return any(self.prog_items[player][item] for item in items)
return any(self.prog_items[item, player] for item in items)
def count(self, item: str, player: int) -> int:
return self.prog_items[player][item]
return self.prog_items[item, player]
def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[player][item_name]
found += self.prog_items[item_name, player]
if found >= count:
return True
return False
@@ -733,11 +736,11 @@ class CollectionState():
def count_group(self, item_name_group: str, player: int) -> int:
found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[player][item_name]
found += self.prog_items[item_name, player]
return found
def item_count(self, item: str, player: int) -> int:
return self.prog_items[player][item]
return self.prog_items[item, player]
def collect(self, item: Item, event: bool = False, location: Optional[Location] = None) -> bool:
if location:
@@ -746,7 +749,7 @@ class CollectionState():
changed = self.multiworld.worlds[item.player].collect(self, item)
if not changed and event:
self.prog_items[item.player][item.name] += 1
self.prog_items[item.name, item.player] += 1
changed = True
self.stale[item.player] = True
@@ -813,83 +816,15 @@ class Region:
locations: List[Location]
entrance_type: ClassVar[Type[Entrance]] = Entrance
class Register(MutableSequence):
region_manager: MultiWorld.RegionManager
def __init__(self, region_manager: MultiWorld.RegionManager):
self._list = []
self.region_manager = region_manager
def __getitem__(self, index: int) -> Location:
return self._list.__getitem__(index)
def __setitem__(self, index: int, value: Location) -> None:
raise NotImplementedError()
def __len__(self) -> int:
return self._list.__len__()
# This seems to not be needed, but that's a bit suspicious.
# def __del__(self):
# self.clear()
def copy(self):
return self._list.copy()
class LocationRegister(Register):
def __delitem__(self, index: int) -> None:
location: Location = self._list.__getitem__(index)
self._list.__delitem__(index)
del(self.region_manager.location_cache[location.player][location.name])
def insert(self, index: int, value: Location) -> None:
self._list.insert(index, value)
self.region_manager.location_cache[value.player][value.name] = value
class EntranceRegister(Register):
def __delitem__(self, index: int) -> None:
entrance: Entrance = self._list.__getitem__(index)
self._list.__delitem__(index)
del(self.region_manager.entrance_cache[entrance.player][entrance.name])
def insert(self, index: int, value: Entrance) -> None:
self._list.insert(index, value)
self.region_manager.entrance_cache[value.player][value.name] = value
_locations: LocationRegister[Location]
_exits: EntranceRegister[Entrance]
def __init__(self, name: str, player: int, multiworld: MultiWorld, hint: Optional[str] = None):
self.name = name
self.entrances = []
self._exits = self.EntranceRegister(multiworld.regions)
self._locations = self.LocationRegister(multiworld.regions)
self.exits = []
self.locations = []
self.multiworld = multiworld
self._hint_text = hint
self.player = player
def get_locations(self):
return self._locations
def set_locations(self, new):
if new is self._locations:
return
self._locations.clear()
self._locations.extend(new)
locations = property(get_locations, set_locations)
def get_exits(self):
return self._exits
def set_exits(self, new):
if new is self._exits:
return
self._exits.clear()
self._exits.extend(new)
exits = property(get_exits, set_exits)
def can_reach(self, state: CollectionState) -> bool:
if state.stale[self.player]:
state.update_reachable_regions(self.player)
@@ -920,7 +855,7 @@ class Region:
self.locations.append(location_type(self.player, location, address, self))
def connect(self, connecting_region: Region, name: Optional[str] = None,
rule: Optional[Callable[[CollectionState], bool]] = None) -> entrance_type:
rule: Optional[Callable[[CollectionState], bool]] = None) -> None:
"""
Connects this Region to another Region, placing the provided rule on the connection.
@@ -931,7 +866,6 @@ class Region:
if rule:
exit_.access_rule = rule
exit_.connect(connecting_region)
return exit_
def create_exit(self, name: str) -> Entrance:
"""