From 286769a0f3eec5745d28f2dfc5d9fd47117d4baa Mon Sep 17 00:00:00 2001 From: Ian Robinson Date: Sun, 8 Feb 2026 11:00:23 -0500 Subject: [PATCH] Core: Add rule builder (#5048) * initial commit of rules engine * implement most of the stuff * add docs and fill out rest of the functionality * add in explain functions * dedupe items and add more docs * pr feedback and optimization updates * Self is not in typing on 3.10 * fix test * Update docs/rule builder.md Co-authored-by: BadMagic100 * pr feedback * love it when CI gives me different results than local * add composition with bitwise and and or * strongly typed option filtering * skip resolving location parent region * update docs * update typing and add decorator * add string explains * move simplify code to world * add wrapper rule * I may need to abandon the generic typing * missing space for faris * fix hashing for resolved rules * thank u typing extensions ilu * remove bad cacheable check * add decorator to assign hash and rule name * more type crimes... * region access rules are now cached * break compatibility so new features work * update docs * replace decorators with __init_subclass__ * ok now the frozen dataclass is automatic * one more type fix for the road * small fixes and caching tests * play nicer with tests * ok actually fix the tests * add item_mapping for faris * add more state helpers as rules * fix has from list rules * fix can reach location caching and add set completion condition * fix can reach entrance caching * implement HasGroup and HasGroupUnique * add more tests and fix some bugs * Add name arg to create_entrance Co-authored-by: roseasromeo <11944660+roseasromeo@users.noreply.github.com> * fix json dumping option filters * restructure and test serialization * add prop to disable caching * switch to __call__ and revert access_rule changes * update docs and make edge cases match * ruff has lured me into a false sense of security * also unused * fix disabling caching * move filter function to filter class * add more docs * tests for explain functions * Update docs/rule builder.md Co-authored-by: roseasromeo <11944660+roseasromeo@users.noreply.github.com> * chore: Strip out uses of TYPE_CHECKING as much as possible * chore: add empty webworld for test * chore: optimize rule evaluations * remove getattr from hot code paths * testing new cache flags * only clear cache for rules cached as false in collect * update test for new behaviour * do not have rules inherit from each other * update docs on caching * fix name of attribute * make explain messages more colorful * fix issue with combining rules with different options * add convenience functions for filtering * use an operator with higher precedence * name conflicts less with optionfilter * move simplify and instance caching code * update docs * kill resolve_rule * kill true_rule and false_rule * move helpers to base classes * update docs * I really should finish all of my * fix test * rename mixin * fix typos * refactor rule builder into folder for better imports * update docs * do not dupe collectionrule * docs review feedback * missed a file * remove rule_caching_enabled from base World * update docs on caching * shuffle around some docs * use option instead of option.value * add in operator and more testing * rm World = object * test fixes * move cache to logic mixin * keep test rule builder world out of global registry * todone * call register_dependencies automatically * move register deps call to call_single * add filtered_resolution * allow bool opts on filters * fix serialization tests * allow reverse operations --------- Co-authored-by: BadMagic100 Co-authored-by: roseasromeo <11944660+roseasromeo@users.noreply.github.com> --- .github/pyright-config.json | 4 + BaseClasses.py | 36 +- docs/rule builder.md | 482 ++++++++ rule_builder/__init__.py | 0 rule_builder/cached_world.py | 146 +++ rule_builder/options.py | 91 ++ rule_builder/rules.py | 1791 +++++++++++++++++++++++++++++ test/general/test_rule_builder.py | 1336 +++++++++++++++++++++ worlds/AutoWorld.py | 78 +- worlds/generic/Rules.py | 30 +- 10 files changed, 3956 insertions(+), 38 deletions(-) create mode 100644 docs/rule builder.md create mode 100644 rule_builder/__init__.py create mode 100644 rule_builder/cached_world.py create mode 100644 rule_builder/options.py create mode 100644 rule_builder/rules.py create mode 100644 test/general/test_rule_builder.py diff --git a/.github/pyright-config.json b/.github/pyright-config.json index 64a46d80cc..fba044da06 100644 --- a/.github/pyright-config.json +++ b/.github/pyright-config.json @@ -2,11 +2,15 @@ "include": [ "../BizHawkClient.py", "../Patch.py", + "../rule_builder/cached_world.py", + "../rule_builder/options.py", + "../rule_builder/rules.py", "../test/param.py", "../test/general/test_groups.py", "../test/general/test_helpers.py", "../test/general/test_memory.py", "../test/general/test_names.py", + "../test/general/test_rule_builder.py", "../test/multiworld/__init__.py", "../test/multiworld/test_multiworlds.py", "../test/netutils/__init__.py", diff --git a/BaseClasses.py b/BaseClasses.py index 4d88fde4f3..75036fc525 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -8,10 +8,10 @@ import secrets import warnings from argparse import Namespace from collections import Counter, deque, defaultdict -from collections.abc import Collection, MutableSequence +from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, MutableSequence, Set from enum import IntEnum, IntFlag -from typing import (AbstractSet, Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Literal, Mapping, NamedTuple, - Optional, Protocol, Set, Tuple, Union, TYPE_CHECKING, Literal, overload) +from typing import (AbstractSet, Any, ClassVar, Dict, List, Literal, NamedTuple, + Optional, Protocol, Tuple, Union, TYPE_CHECKING, overload) import dataclasses from typing_extensions import NotRequired, TypedDict @@ -85,7 +85,7 @@ class MultiWorld(): local_items: Dict[int, Options.LocalItems] non_local_items: Dict[int, Options.NonLocalItems] progression_balancing: Dict[int, Options.ProgressionBalancing] - completion_condition: Dict[int, Callable[[CollectionState], bool]] + completion_condition: Dict[int, CollectionRule] indirect_connections: Dict[Region, Set[Entrance]] exclude_locations: Dict[int, Options.ExcludeLocations] priority_locations: Dict[int, Options.PriorityLocations] @@ -766,7 +766,7 @@ class CollectionState(): else: self._update_reachable_regions_auto_indirect_conditions(player, queue) - def _update_reachable_regions_explicit_indirect_conditions(self, player: int, queue: deque): + def _update_reachable_regions_explicit_indirect_conditions(self, player: int, queue: deque[Entrance]): reachable_regions = self.reachable_regions[player] blocked_connections = self.blocked_connections[player] # run BFS on all connections, and keep track of those blocked by missing items @@ -784,13 +784,14 @@ class CollectionState(): blocked_connections.update(new_region.exits) queue.extend(new_region.exits) self.path[new_region] = (new_region.name, self.path.get(connection, None)) + self.multiworld.worlds[player].reached_region(self, new_region) # Retry connections if the new region can unblock them for new_entrance in self.multiworld.indirect_connections.get(new_region, set()): if new_entrance in blocked_connections and new_entrance not in queue: queue.append(new_entrance) - def _update_reachable_regions_auto_indirect_conditions(self, player: int, queue: deque): + def _update_reachable_regions_auto_indirect_conditions(self, player: int, queue: deque[Entrance]): reachable_regions = self.reachable_regions[player] blocked_connections = self.blocked_connections[player] new_connection: bool = True @@ -812,6 +813,7 @@ class CollectionState(): queue.extend(new_region.exits) self.path[new_region] = (new_region.name, self.path.get(connection, None)) new_connection = True + self.multiworld.worlds[player].reached_region(self, new_region) # sweep for indirect connections, mostly Entrance.can_reach(unrelated_Region) queue.extend(blocked_connections) @@ -1169,13 +1171,17 @@ class CollectionState(): self.prog_items[player][item] = count +CollectionRule = Callable[[CollectionState], bool] +DEFAULT_COLLECTION_RULE: CollectionRule = staticmethod(lambda state: True) + + class EntranceType(IntEnum): ONE_WAY = 1 TWO_WAY = 2 class Entrance: - access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True) + access_rule: CollectionRule = DEFAULT_COLLECTION_RULE hide_path: bool = False player: int name: str @@ -1362,7 +1368,7 @@ class Region: self, location_name: str, item_name: str | None = None, - rule: Callable[[CollectionState], bool] | None = None, + rule: CollectionRule | None = None, location_type: type[Location] | None = None, item_type: type[Item] | None = None, show_in_spoiler: bool = True, @@ -1401,7 +1407,7 @@ class Region: return event_item def connect(self, connecting_region: Region, name: Optional[str] = None, - rule: Optional[Callable[[CollectionState], bool]] = None) -> Entrance: + rule: Optional[CollectionRule] = None) -> Entrance: """ Connects this Region to another Region, placing the provided rule on the connection. @@ -1435,7 +1441,7 @@ class Region: return entrance def add_exits(self, exits: Iterable[str] | Mapping[str, str | None], - rules: Mapping[str, Callable[[CollectionState], bool]] | None = None) -> List[Entrance]: + rules: Mapping[str, CollectionRule] | None = None) -> List[Entrance]: """ Connects current region to regions in exit dictionary. Passed region names must exist first. @@ -1474,7 +1480,7 @@ class Location: show_in_spoiler: bool = True progress_type: LocationProgressType = LocationProgressType.DEFAULT always_allow: Callable[[CollectionState, Item], bool] = staticmethod(lambda state, item: False) - access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True) + access_rule: CollectionRule = DEFAULT_COLLECTION_RULE item_rule: Callable[[Item], bool] = staticmethod(lambda item: True) item: Optional[Item] = None @@ -1551,7 +1557,7 @@ class ItemClassification(IntFlag): skip_balancing = 0b01000 """ should technically never occur on its own Item that is logically relevant, but progression balancing should not touch. - + Possible reasons for why an item should not be pulled ahead by progression balancing: 1. This item is quite insignificant, so pulling it earlier doesn't help (currency/etc.) 2. It is important for the player experience that this item is evenly distributed in the seed (e.g. goal items) """ @@ -1559,13 +1565,13 @@ class ItemClassification(IntFlag): deprioritized = 0b10000 """ Should technically never occur on its own. Will not be considered for priority locations, - unless Priority Locations Fill runs out of regular progression items before filling all priority locations. - + unless Priority Locations Fill runs out of regular progression items before filling all priority locations. + Should be used for items that would feel bad for the player to find on a priority location. Usually, these are items that are plentiful or insignificant. """ progression_deprioritized_skip_balancing = 0b11001 - """ Since a common case of both skip_balancing and deprioritized is "insignificant progression", + """ Since a common case of both skip_balancing and deprioritized is "insignificant progression", these items often want both flags. """ progression_skip_balancing = 0b01001 # only progression gets balanced diff --git a/docs/rule builder.md b/docs/rule builder.md new file mode 100644 index 0000000000..4f9102a2ba --- /dev/null +++ b/docs/rule builder.md @@ -0,0 +1,482 @@ +# Rule Builder + +This document describes the API provided for the rule builder. Using this API provides you with with a simple interface to define rules and the following advantages: + +- Rule classes that avoid all the common pitfalls +- Logic optimization +- Automatic result caching (opt-in) +- Serialization/deserialization +- Human-readable logic explanations for players + +## Overview + +The rule builder consists of 3 main parts: + +1. The rules, which are classes that inherit from `rule_builder.rules.Rule`. These are what you write for your logic. They can be combined and take into account your world's options. There are a number of default rules listed below, and you can create as many custom rules for your world as needed. When assigning the rules to a location or entrance they must be resolved. +1. Resolved rules, which are classes that inherit from `rule_builder.rules.Rule.Resolved`. These are the optimized rules specific to one player that are set as a location or entrance's access rule. You generally shouldn't be directly creating these but they'll be created when assigning rules to locations or entrances. These are what power the human-readable logic explanations. +1. The optional rule builder world subclass `CachedRuleBuilderWorld`, which is a class your world can inherit from instead of `World`. It adds a caching system to the rules that will lazy evaluate and cache the result. + +## Usage + +For the most part the only difference in usage is instead of writing lambdas for your logic, you write static Rule objects. You then must use `world.set_rule` to assign the rule to a location or entrance. + +```python +# In your world's create_regions method +location = MyWorldLocation(...) +self.set_rule(location, Has("A Big Gun")) +``` + +The rule builder comes with a number of rules by default: + +- `True_`: Always returns true +- `False_`: Always returns false +- `And`: Checks that all child rules are true (also provided by `&` operator) +- `Or`: Checks that at least one child rule is true (also provided by `|` operator) +- `Has`: Checks that the player has the given item with the given count (default 1) +- `HasAll`: Checks that the player has all given items +- `HasAny`: Checks that the player has at least one of the given items +- `HasAllCounts`: Checks that the player has all of the counts for the given items +- `HasAnyCount`: Checks that the player has any of the counts for the given items +- `HasFromList`: Checks that the player has some number of given items +- `HasFromListUnique`: Checks that the player has some number of given items, ignoring duplicates of the same item +- `HasGroup`: Checks that the player has some number of items from a given item group +- `HasGroupUnique`: Checks that the player has some number of items from a given item group, ignoring duplicates of the same item +- `CanReachLocation`: Checks that the player can logically reach the given location +- `CanReachRegion`: Checks that the player can logically reach the given region +- `CanReachEntrance`: Checks that the player can logically reach the given entrance + +You can combine these rules together to describe the logic required for something. For example, to check if a player either has `Movement ability` or they have both `Key 1` and `Key 2`, you can do: + +```python +rule = Has("Movement ability") | HasAll("Key 1", "Key 2") +``` + +> ⚠️ Composing rules with the `and` and `or` keywords will not work. You must use the bitwise `&` and `|` operators. In order to catch mistakes, the rule builder will not let you do boolean operations. As a consequence, in order to check if a rule is defined you must use `if rule is not None`. + +### Assigning rules + +When assigning the rule you must use the `set_rule` helper to correctly resolve and register the rule. + +```python +self.set_rule(location_or_entrance, rule) +``` + +There is also a `create_entrance` helper that will resolve the rule, check if it's `False`, and if not create the entrance and set the rule. This allows you to skip creating entrances that will never be valid. You can also specify `force_creation=True` if you would like to create the entrance even if the rule is `False`. + +```python +self.create_entrance(from_region, to_region, rule) +``` + +> ⚠️ If you use a `CanReachLocation` rule on an entrance, you will either have to create the locations first, or specify the location's parent region name with the `parent_region_name` argument of `CanReachLocation`. + +You can also set a rule for your world's completion condition: + +```python +self.set_completion_rule(rule) +``` + +### Restricting options + +Every rule allows you to specify which options it's applicable for. You can provide the argument `options` which is an iterable of `OptionFilter` instances. Rules that pass the options check will be resolved as normal, and those that fail will be resolved as `False`. + +If you want a comparison that isn't equals, you can specify with the `operator` argument. The following operators are allowed: + +- `eq`: `==` +- `ne`: `!=` +- `gt`: `>` +- `lt`: `<` +- `ge`: `>=` +- `le`: `<=` +- `contains`: `in` + +By default rules that are excluded by their options will default to `False`. If you want to default to `True` instead, you can specify `filtered_resolution=True` on your rule. + +To check if the player can reach a switch, or if they've received the switch item if switches are randomized: + +```python +rule = ( + Has("Red switch", options=[OptionFilter(SwitchRando, 1)]) + | CanReachLocation("Red switch", options=[OptionFilter(SwitchRando, 0)]) +) +``` + +To add an extra logic requirement on the easiest difficulty which is ignored for other difficulties: + +```python +rule = ( + # ...the rest of the logic + & Has("QoL item", options=[OptionFilter(Difficulty, Difficulty.option_easy)], filtered_resolution=True) +) +``` + +If you would like to provide option filters when reusing or composing rules, you can use the `Filtered` helper rule: + +```python +common_rule = Has("A") | HasAny("B", "C") +... +rule = ( + Filtered(common_rule, options=[OptionFilter(Opt, 0)]), + | Filtered(Has("X") | CanReachRegion("Y"), options=[OptionFilter(Opt, 1)]), +) +``` + +You can also use the & and | operators to apply options to rules: + +```python +common_rule = Has("A") +easy_filter = [OptionFilter(Difficulty, Difficulty.option_easy)] +common_rule_only_on_easy = common_rule & easy_filter +common_rule_skipped_on_easy = common_rule | easy_filter +``` + +## Enabling caching + +The rule builder provides a `CachedRuleBuilderWorld` base class for your `World` class that enables caching on your rules. + +```python +class MyWorld(CachedRuleBuilderWorld): + game = "My Game" +``` + +If your world's logic is very simple and you don't have many nested rules, the caching system may have more overhead cost than time it saves. You'll have to benchmark your own world to see if it should be enabled or not. + +### Item name mapping + +If you have multiple real items that map to a single logic item, add a `item_mapping` class dict to your world that maps actual item names to real item names so the cache system knows what to invalidate. + +For example, if you have multiple `Currency x` items on locations, but your rules only check a singular logical `Currency` item, eg `Has("Currency", 1000)`, you'll want to map each numerical currency item to the single logical `Currency`. + +```python +class MyWorld(CachedRuleBuilderWorld): + item_mapping = { + "Currency x10": "Currency", + "Currency x50": "Currency", + "Currency x100": "Currency", + "Currency x500": "Currency", + } +``` + +## Defining custom rules + +You can create a custom rule by creating a class that inherits from `Rule` or any of the default rules. You must provide the game name as an argument to the class. It's recommended to use the `@dataclass` decorator to reduce boilerplate, and to also provide your world as a type argument to add correct type checking to the `_instantiate` method. + +You must provide or inherit a `Resolved` child class that defines an `_evaluate` method. This class will automatically be converted into a frozen `dataclass`. If your world has caching enabled you may need to define one or more dependencies functions as outlined below. + +To add a rule that checks if the user has enough mcguffins to goal, with a randomized requirement: + +```python +@dataclasses.dataclass() +class CanGoal(Rule["MyWorld"], game="My Game"): + @override + def _instantiate(self, world: "MyWorld") -> Rule.Resolved: + # caching_enabled only needs to be passed in when your world inherits from CachedRuleBuilderWorld + return self.Resolved(world.required_mcguffins, player=world.player, caching_enabled=True) + + class Resolved(Rule.Resolved): + goal: int + + @override + def _evaluate(self, state: CollectionState) -> bool: + return state.has("McGuffin", self.player, count=self.goal) + + @override + def item_dependencies(self) -> dict[str, set[int]]: + # this function is only required if you have caching enabled + return {"McGuffin": {id(self)}} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + # this method can be overridden to display custom explanations + return [ + {"type": "text", "text": "Goal with "}, + {"type": "color", "color": "green" if state and self(state) else "salmon", "text": str(self.goal)}, + {"type": "text", "text": " McGuffins"}, + ] +``` + +Your custom rule can also resolve to builtin rules instead of needing to define your own: + +```python +@dataclasses.dataclass() +class ComplicatedFilter(Rule["MyWorld"], game="My Game"): + def _instantiate(self, world: "MyWorld") -> Rule.Resolved: + if world.some_precalculated_bool: + return Has("Item 1").resolve(world) + if world.options.some_option: + return CanReachRegion("Region 1").resolve(world) + return False_().resolve(world) +``` + +### Item dependencies + +If your world inherits from `CachedRuleBuilderWorld` and there are items that when collected will affect the result of your rule evaluation, it must define an `item_dependencies` function that returns a mapping of the item name to the id of your rule. These dependencies will be combined to inform the caching system. It may be worthwhile to define this function even when caching is disabled as more things may use it in the future. + +```python +@dataclasses.dataclass() +class MyRule(Rule["MyWorld"], game="My Game"): + class Resolved(Rule.Resolved): + item_name: str + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {self.item_name: {id(self)}} +``` + +All of the default `Has*` rules define this function already. + +### Region dependencies + +If your custom rule references other regions, it must define a `region_dependencies` function that returns a mapping of region names to the id of your rule regardless of if your world inherits from `CachedRuleBuilderWorld`. These dependencies will be combined to register indirect connections when you set this rule on an entrance and inform the caching system if applicable. + +```python +@dataclasses.dataclass() +class MyRule(Rule["MyWorld"], game="My Game"): + class Resolved(Rule.Resolved): + region_name: str + + @override + def region_dependencies(self) -> dict[str, set[int]]: + return {self.region_name: {id(self)}} +``` + +The default `CanReachLocation`, `CanReachRegion`, and `CanReachEntrance` rules define this function already. + +### Location dependencies + +If your custom rule references other locations, it must define a `location_dependencies` function that returns a mapping of the location name to the id of your rule regardless of if your world inherits from `CachedRuleBuilderWorld`. These dependencies will be combined to register indirect connections when you set this rule on an entrance and inform the caching system if applicable. + +```python +@dataclasses.dataclass() +class MyRule(Rule["MyWorld"], game="My Game"): + class Resolved(Rule.Resolved): + location_name: str + + @override + def location_dependencies(self) -> dict[str, set[int]]: + return {self.location_name: {id(self)}} +``` + +The default `CanReachLocation` rule defines this function already. + +### Entrance dependencies + +If your custom rule references other entrances, it must define a `entrance_dependencies` function that returns a mapping of the entrance name to the id of your rule regardless of if your world inherits from `CachedRuleBuilderWorld`. These dependencies will be combined to register indirect connections when you set this rule on an entrance and inform the caching system if applicable. + +```python +@dataclasses.dataclass() +class MyRule(Rule["MyWorld"], game="My Game"): + class Resolved(Rule.Resolved): + entrance_name: str + + @override + def entrance_dependencies(self) -> dict[str, set[int]]: + return {self.entrance_name: {id(self)}} +``` + +The default `CanReachEntrance` rule defines this function already. + +### Rule explanations + +Resolved rules have a default implementation for `explain_json` and `explain_str` functions. The former optionally accepts a `CollectionState` and returns a list of `JSONMessagePart` appropriate for `print_json` in a client. It will display a human-readable message that explains what the rule requires. The latter is similar but returns a string. It is useful when debugging. There is also a `__str__` method defined to check what a rule is without a state. + +To implement a custom message with a custom rule, override the `explain_json` and/or `explain_str` method on your `Resolved` class: + +```python +class MyRule(Rule, game="My Game"): + class Resolved(Rule.Resolved): + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + has_item = state and state.has("growth spurt", self.player) + color = "yellow" + start = "You must be " + if has_item: + start = "You are " + color = "green" + elif state is not None: + start = "You are not " + color = "salmon" + return [ + {"type": "text", "text": start}, + {"type": "color", "color": color, "text": "THIS"}, + {"type": "text", "text": " tall to beat the game"}, + ] + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + if state.has("growth spurt", self.player): + return "You ARE this tall and can beat the game" + return "You are not THIS tall and cannot beat the game" + + @override + def __str__(self) -> str: + return "You must be THIS tall to beat the game" +``` + +### Cache control + +By default your custom rule will work through the cache system as any other rule if caching is enabled. There are two class attributes on the `Resolved` class you can override to change this behavior. + +- `force_recalculate`: Setting this to `True` will cause your custom rule to skip going through the caching system and always recalculate when being evaluated. When a rule with this flag enabled is composed with `And` or `Or` it will cause any parent rules to always force recalculate as well. Use this flag when it's difficult to determine when your rule should be marked as stale. +- `skip_cache`: Setting this to `True` will also cause your custom rule to skip going through the caching system when being evaluated. However, it will **not** affect any other rules when composed with `And` or `Or`, so it must still define its `*_dependencies` functions as required. Use this flag when the evaluation of this rule is trivial and the overhead of the caching system will slow it down. + +### Caveats + +- Ensure you are passing `caching_enabled=True` in your `_instantiate` function when creating resolved rule instances if your world has opted into caching. +- Resolved rules are forced to be frozen dataclasses. They and all their attributes must be immutable and hashable. +- If your rule creates child rules ensure they are being resolved through the world rather than creating `Resolved` instances directly. + +## Serialization + +The rule builder is intended to be written first in Python for optimization and type safety. To facilitate exporting the rules to a client or tracker, rules have a `to_dict` method that returns a JSON-compatible dict. Since the location and entrance logic structure varies greatly from world to world, the actual JSON dumping is left up to the world dev. + +The dict contains a `rule` key with the name of the rule, an `options` key with the rule's list of option filters, and an `args` key that contains any other arguments the individual rule has. For example, this is what a simple `Has` rule would look like: + +```python +{ + "rule": "Has", + "options": [], + "args": { + "item_name": "Some item", + "count": 1, + }, +} +``` + +For `And` and `Or` rules, instead of an `args` key, they have a `children` key containing a list of their child rules in the same serializable format: + +```python +{ + "rule": "And", + "options": [], + "children": [ + ..., # each serialized rule + ] +} +``` + +A full example is as follows: + +```python +rule = And( + Has("a", options=[OptionFilter(ToggleOption, 0)]), + Or(Has("b", count=2), CanReachRegion("c"), options=[OptionFilter(ToggleOption, 1)]), +) +assert rule.to_dict() == { + "rule": "And", + "options": [], + "children": [ + { + "rule": "Has", + "options": [ + { + "option": "worlds.my_world.options.ToggleOption", + "value": 0, + "operator": "eq", + }, + ], + "args": { + "item_name": "a", + "count": 1, + }, + }, + { + "rule": "Or", + "options": [ + { + "option": "worlds.my_world.options.ToggleOption", + "value": 1, + "operator": "eq", + }, + ], + "children": [ + { + "rule": "Has", + "options": [], + "args": { + "item_name": "b", + "count": 2, + }, + }, + { + "rule": "CanReachRegion", + "options": [], + "args": { + "region_name": "c", + }, + }, + ], + }, + ], +} +``` + +### Custom serialization + +To define a different format for your custom rules, override the `to_dict` function: + +```python +class BasicLogicRule(Rule, game="My Game"): + items = ("one", "two") + + def to_dict(self) -> dict[str, Any]: + # Return whatever format works best for you + return { + "logic": "basic", + "items": self.items, + } +``` + +If your logic has been done in custom JSON first, you can define a `from_dict` class method on your rules to parse it correctly: + +```python +class BasicLogicRule(Rule, game="My Game"): + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: type[World]) -> Self: + items = data.get("items", ()) + return cls(*items) +``` + +## APIs + +This section is provided for reference, refer to the above sections for examples. + +### World API + +These are properties and helpers that are available to you in your world. + +#### Methods + +- `rule_from_dict(data)`: Create a rule instance from a deserialized dict representation +- `register_rule_builder_dependencies()`: Register all rules that depend on location or entrance access with the inherited dependencies, gets called automatically after set_rules +- `set_rule(spot: Location | Entrance, rule: Rule)`: Resolve a rule, register its dependencies, and set it on the given location or entrance +- `set_completion_rule(rule: Rule)`: Sets the completion condition for this world +- `create_entrance(from_region: Region, to_region: Region, rule: Rule | None, name: str | None = None, force_creation: bool = False)`: Attempt to create an entrance from `from_region` to `to_region`, skipping creation if `rule` is defined and evaluates to `False_()` unless force_creation is `True` + +#### CachedRuleBuilderWorld Properties + +The following property is only available when inheriting from `CachedRuleBuilderWorld` + +- `item_mapping: dict[str, str]`: A mapping of actual item name to logical item name + +### Rule API + +These are properties and helpers that you can use or override for custom rules. + +- `_instantiate(world: World)`: Create a new resolved rule instance, override for custom rules as required +- `to_dict()`: Create a JSON-compatible dict representation of this rule, override if you want to customize your rule's serialization +- `from_dict(data, world_cls: type[World])`: Return a new rule instance from a deserialized representation, override if you've overridden `to_dict` +- `__str__()`: Basic string representation of a rule, useful for debugging + +#### Resolved rule API + +- `player: int`: The slot this rule is resolved for +- `_evaluate(state: CollectionState)`: Evaluate this rule against the given state, override this to define the logic for this rule +- `item_dependencies()`: A mapping of item name to set of ids, override this if your custom rule depends on item collection +- `region_dependencies()`: A mapping of region name to set of ids, override this if your custom rule depends on reaching regions +- `location_dependencies()`: A mapping of location name to set of ids, override this if your custom rule depends on reaching locations +- `entrance_dependencies()`: A mapping of entrance name to set of ids, override this if your custom rule depends on reaching entrances +- `explain_json(state: CollectionState | None = None)`: Return a list of printJSON messages describing this rule's logic (and if state is defined its evaluation) in a human readable way, override to explain custom rules +- `explain_str(state: CollectionState | None = None)`: Return a string describing this rule's logic (and if state is defined its evaluation) in a human readable way, override to explain custom rules, more useful for debugging +- `__str__()`: A string describing this rule's logic without its evaluation, override to explain custom rules diff --git a/rule_builder/__init__.py b/rule_builder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rule_builder/cached_world.py b/rule_builder/cached_world.py new file mode 100644 index 0000000000..bb7cc4d9b5 --- /dev/null +++ b/rule_builder/cached_world.py @@ -0,0 +1,146 @@ +from collections import defaultdict +from typing import ClassVar, cast + +from typing_extensions import override + +from BaseClasses import CollectionState, Item, MultiWorld, Region +from worlds.AutoWorld import LogicMixin, World + +from .rules import Rule + + +class CachedRuleBuilderWorld(World): + """A World subclass that provides helpers for interacting with the rule builder""" + + rule_item_dependencies: dict[str, set[int]] + """A mapping of item name to set of rule ids""" + + rule_region_dependencies: dict[str, set[int]] + """A mapping of region name to set of rule ids""" + + rule_location_dependencies: dict[str, set[int]] + """A mapping of location name to set of rule ids""" + + rule_entrance_dependencies: dict[str, set[int]] + """A mapping of entrance name to set of rule ids""" + + item_mapping: ClassVar[dict[str, str]] = {} + """A mapping of actual item name to logical item name. + Useful when there are multiple versions of a collected item but the logic only uses one. For example: + item = Item("Currency x500"), rule = Has("Currency", count=1000), item_mapping = {"Currency x500": "Currency"}""" + + rule_caching_enabled: ClassVar[bool] = True + """Flag to inform rules that the caching system for this world is enabled. It should not be overridden.""" + + def __init__(self, multiworld: MultiWorld, player: int) -> None: + super().__init__(multiworld, player) + self.rule_item_dependencies = defaultdict(set) + self.rule_region_dependencies = defaultdict(set) + self.rule_location_dependencies = defaultdict(set) + self.rule_entrance_dependencies = defaultdict(set) + + @override + def register_rule_dependencies(self, resolved_rule: Rule.Resolved) -> None: + for item_name, rule_ids in resolved_rule.item_dependencies().items(): + self.rule_item_dependencies[item_name] |= rule_ids + for region_name, rule_ids in resolved_rule.region_dependencies().items(): + self.rule_region_dependencies[region_name] |= rule_ids + for location_name, rule_ids in resolved_rule.location_dependencies().items(): + self.rule_location_dependencies[location_name] |= rule_ids + for entrance_name, rule_ids in resolved_rule.entrance_dependencies().items(): + self.rule_entrance_dependencies[entrance_name] |= rule_ids + + def register_rule_builder_dependencies(self) -> None: + """Register all rules that depend on locations or entrances with their dependencies""" + for location_name, rule_ids in self.rule_location_dependencies.items(): + try: + location = self.get_location(location_name) + except KeyError: + continue + if not isinstance(location.access_rule, Rule.Resolved): + continue + for item_name in location.access_rule.item_dependencies(): + self.rule_item_dependencies[item_name] |= rule_ids + for region_name in location.access_rule.region_dependencies(): + self.rule_region_dependencies[region_name] |= rule_ids + + for entrance_name, rule_ids in self.rule_entrance_dependencies.items(): + try: + entrance = self.get_entrance(entrance_name) + except KeyError: + continue + if not isinstance(entrance.access_rule, Rule.Resolved): + continue + for item_name in entrance.access_rule.item_dependencies(): + self.rule_item_dependencies[item_name] |= rule_ids + for region_name in entrance.access_rule.region_dependencies(): + self.rule_region_dependencies[region_name] |= rule_ids + + @override + def collect(self, state: CollectionState, item: Item) -> bool: + changed = super().collect(state, item) + if changed and self.rule_item_dependencies: + player_results = cast(dict[int, bool], state.rule_builder_cache[self.player]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + mapped_name = self.item_mapping.get(item.name, "") + rule_ids = self.rule_item_dependencies[item.name] | self.rule_item_dependencies[mapped_name] + for rule_id in rule_ids: + if player_results.get(rule_id, None) is False: + del player_results[rule_id] + + return changed + + @override + def remove(self, state: CollectionState, item: Item) -> bool: + changed = super().remove(state, item) + if not changed: + return changed + + player_results = cast(dict[int, bool], state.rule_builder_cache[self.player]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + if self.rule_item_dependencies: + mapped_name = self.item_mapping.get(item.name, "") + rule_ids = self.rule_item_dependencies[item.name] | self.rule_item_dependencies[mapped_name] + for rule_id in rule_ids: + player_results.pop(rule_id, None) + + # clear all region dependent caches as none can be trusted + if self.rule_region_dependencies: + for rule_ids in self.rule_region_dependencies.values(): + for rule_id in rule_ids: + player_results.pop(rule_id, None) + + # clear all location dependent caches as they may have lost region access + if self.rule_location_dependencies: + for rule_ids in self.rule_location_dependencies.values(): + for rule_id in rule_ids: + player_results.pop(rule_id, None) + + # clear all entrance dependent caches as they may have lost region access + if self.rule_entrance_dependencies: + for rule_ids in self.rule_entrance_dependencies.values(): + for rule_id in rule_ids: + player_results.pop(rule_id, None) + + return changed + + @override + def reached_region(self, state: CollectionState, region: Region) -> None: + super().reached_region(state, region) + if self.rule_region_dependencies: + player_results = cast(dict[int, bool], state.rule_builder_cache[self.player]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + for rule_id in self.rule_region_dependencies[region.name]: + player_results.pop(rule_id, None) + + +class CachedRuleBuilderLogicMixin(LogicMixin): + multiworld: MultiWorld # pyright: ignore[reportUninitializedInstanceVariable] + rule_builder_cache: dict[int, dict[int, bool]] # pyright: ignore[reportUninitializedInstanceVariable] + + def init_mixin(self, multiworld: "MultiWorld") -> None: + players = multiworld.get_all_ids() + self.rule_builder_cache = {player: {} for player in players} + + def copy_mixin(self, new_state: "CachedRuleBuilderLogicMixin") -> "CachedRuleBuilderLogicMixin": + new_state.rule_builder_cache = { + player: player_results.copy() for player, player_results in self.rule_builder_cache.items() + } + return new_state diff --git a/rule_builder/options.py b/rule_builder/options.py new file mode 100644 index 0000000000..979a72315e --- /dev/null +++ b/rule_builder/options.py @@ -0,0 +1,91 @@ +import dataclasses +import importlib +import operator +from collections.abc import Callable, Iterable +from typing import Any, Final, Literal, Self, cast + +from typing_extensions import override + +from Options import CommonOptions, Option + +Operator = Literal["eq", "ne", "gt", "lt", "ge", "le", "contains", "in"] + +OPERATORS: Final[dict[Operator, Callable[..., bool]]] = { + "eq": operator.eq, + "ne": operator.ne, + "gt": operator.gt, + "lt": operator.lt, + "ge": operator.ge, + "le": operator.le, + "contains": operator.contains, + "in": operator.contains, +} +OPERATOR_STRINGS: Final[dict[Operator, str]] = { + "eq": "==", + "ne": "!=", + "gt": ">", + "lt": "<", + "ge": ">=", + "le": "<=", +} +REVERSE_OPERATORS: Final[tuple[Operator, ...]] = ("in",) + + +@dataclasses.dataclass(frozen=True) +class OptionFilter: + option: type[Option[Any]] + value: Any + operator: Operator = "eq" + + def to_dict(self) -> dict[str, Any]: + """Returns a JSON compatible dict representation of this option filter""" + return { + "option": f"{self.option.__module__}.{self.option.__name__}", + "value": self.value, + "operator": self.operator, + } + + def check(self, options: CommonOptions) -> bool: + """Tests the given options dataclass to see if it passes this option filter""" + option_name = next( + (name for name, cls in options.__class__.type_hints.items() if cls is self.option), + None, + ) + if option_name is None: + raise ValueError(f"Cannot find option {self.option.__name__} in options class {options.__class__.__name__}") + opt = cast(Option[Any] | None, getattr(options, option_name, None)) + if opt is None: + raise ValueError(f"Invalid option: {option_name}") + + fn = OPERATORS[self.operator] + return fn(self.value, opt) if self.operator in REVERSE_OPERATORS else fn(opt, self.value) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: + """Returns a new OptionFilter instance from a dict representation""" + if "option" not in data or "value" not in data: + raise ValueError("Missing required value and/or option") + + option_path = data["option"] + try: + option_mod_name, option_cls_name = option_path.rsplit(".", 1) + option_module = importlib.import_module(option_mod_name) + option = getattr(option_module, option_cls_name, None) + except (ValueError, ImportError) as e: + raise ValueError(f"Cannot parse option '{option_path}'") from e + if option is None or not issubclass(option, Option): + raise ValueError(f"Invalid option '{option_path}' returns type '{option}' instead of Option subclass") + + value = data["value"] + operator = data.get("operator", "eq") + return cls(option=cast(type[Option[Any]], option), value=value, operator=operator) + + @classmethod + def multiple_from_dict(cls, data: Iterable[dict[str, Any]]) -> tuple[Self, ...]: + """Returns a tuple of OptionFilters instances from an iterable of dict representations""" + return tuple(cls.from_dict(o) for o in data) + + @override + def __str__(self) -> str: + op = OPERATOR_STRINGS.get(self.operator, self.operator) + return f"{self.option.__name__} {op} {self.value}" diff --git a/rule_builder/rules.py b/rule_builder/rules.py new file mode 100644 index 0000000000..0e9396fb60 --- /dev/null +++ b/rule_builder/rules.py @@ -0,0 +1,1791 @@ +import dataclasses +from collections.abc import Callable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Never, Self, cast + +from typing_extensions import TypeVar, dataclass_transform, override + +from BaseClasses import CollectionState +from NetUtils import JSONMessagePart + +from .options import OptionFilter + +if TYPE_CHECKING: + from worlds.AutoWorld import World + + TWorld = TypeVar("TWorld", bound=World, contravariant=True, default=World) # noqa: PLC0105 +else: + TWorld = TypeVar("TWorld") + + +def _create_hash_fn(resolved_rule_cls: "CustomRuleRegister") -> Callable[..., int]: + def hash_impl(self: "Rule.Resolved") -> int: + return hash( + ( + self.__class__.__module__, + self.rule_name, + *[getattr(self, f.name) for f in dataclasses.fields(self)], + ) + ) + + hash_impl.__qualname__ = f"{resolved_rule_cls.__qualname__}.__hash__" + return hash_impl + + +@dataclass_transform(frozen_default=True, field_specifiers=(dataclasses.field, dataclasses.Field)) +class CustomRuleRegister(type): + """A metaclass to contain world custom rules and automatically convert resolved rules to frozen dataclasses""" + + resolved_rules: ClassVar[dict[int, "Rule.Resolved"]] = {} + """A cached of resolved rules to turn each unique one into a singleton""" + + custom_rules: ClassVar[dict[str, dict[str, type["Rule[Any]"]]]] = {} + """A mapping of game name to mapping of rule name to rule class to hold custom rules implemented by worlds""" + + rule_name: str = "Rule" + """The string name of a rule, must be unique per game""" + + def __new__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + /, + **kwds: dict[str, Any], + ) -> type["CustomRuleRegister"]: + new_cls = super().__new__(cls, name, bases, namespace, **kwds) + new_cls.__hash__ = _create_hash_fn(new_cls) + rule_name = new_cls.__qualname__ + if rule_name.endswith(".Resolved"): + rule_name = rule_name[:-9] + new_cls.rule_name = rule_name + return dataclasses.dataclass(frozen=True)(new_cls) + + @override + def __call__(cls, *args: Any, **kwds: Any) -> Any: + rule = super().__call__(*args, **kwds) + rule_hash = hash(rule) + if rule_hash in cls.resolved_rules: + return cls.resolved_rules[rule_hash] + cls.resolved_rules[rule_hash] = rule + return rule + + @classmethod + def get_rule_cls(cls, game_name: str, rule_name: str) -> type["Rule[Any]"]: + """Returns the world-registered or default rule with the given name""" + custom_rule_classes = cls.custom_rules.get(game_name, {}) + if rule_name not in DEFAULT_RULES and rule_name not in custom_rule_classes: + raise ValueError(f"Rule '{rule_name}' for game '{game_name}' not found") + return custom_rule_classes.get(rule_name) or DEFAULT_RULES[rule_name] + + +@dataclasses.dataclass() +class Rule(Generic[TWorld]): + """Base class for a static rule used to generate an access rule""" + + options: Iterable[OptionFilter] = dataclasses.field(default=(), kw_only=True) + """An iterable of OptionFilters to restrict what options are required for this rule to be active""" + + filtered_resolution: bool = dataclasses.field(default=False, kw_only=True) + """If this rule should default to True or False when filtered by its options""" + + game_name: ClassVar[str] + """The name of the game this rule belongs to, default rules belong to 'Archipelago'""" + + def __post_init__(self) -> None: + if not isinstance(self.options, tuple): + self.options = tuple(self.options) + + def _instantiate(self, world: TWorld) -> "Resolved": + """Create a new resolved rule for this world""" + return self.Resolved(player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False)) + + def resolve(self, world: TWorld) -> "Resolved": + """Resolve a rule with the given world""" + for option_filter in self.options: + if not option_filter.check(world.options): + return True_().resolve(world) if self.filtered_resolution else False_().resolve(world) + return self._instantiate(world) + + def to_dict(self) -> dict[str, Any]: + """Returns a JSON compatible dict representation of this rule""" + args = { + field.name: getattr(self, field.name, None) + for field in dataclasses.fields(self) + if field.name not in ("options", "filtered_resolution") + } + return { + "rule": self.__class__.__qualname__, + "options": [o.to_dict() for o in self.options], + "filtered_resolution": self.filtered_resolution, + "args": args, + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + """Returns a new instance of this rule from a serialized dict representation""" + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(**data.get("args", {}), options=options, filtered_resolution=data.get("filtered_resolution", False)) + + def __and__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]": + """Combines two rules or a rule and an option filter into an And rule""" + if isinstance(other, OptionFilter): + other = (other,) + if isinstance(other, Iterable): + if not other: + return self + return Filtered(self, options=other) + if self.options == other.options: + if isinstance(self, And): + if isinstance(other, And): + return And(*self.children, *other.children, options=self.options) + return And(*self.children, other, options=self.options) + if isinstance(other, And): + return And(self, *other.children, options=other.options) + return And(self, other) + + def __rand__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]": + return self.__and__(other) + + def __or__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]": + """Combines two rules or a rule and an option filter into an Or rule""" + if isinstance(other, OptionFilter): + other = (other,) + if isinstance(other, Iterable): + if not other: + return self + return Or(self, True_(options=other)) + if self.options == other.options: + if isinstance(self, Or): + if isinstance(other, Or): + return Or(*self.children, *other.children, options=self.options) + return Or(*self.children, other, options=self.options) + if isinstance(other, Or): + return Or(self, *other.children, options=self.options) + return Or(self, other) + + def __ror__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]": + return self.__or__(other) + + def __bool__(self) -> Never: + """Safeguard to prevent devs from mistakenly doing `rule1 and rule2` and getting the wrong result""" + raise TypeError("Use & or | to combine rules, or use `is not None` for boolean tests") + + @override + def __str__(self) -> str: + options = f"options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({options})" + + @classmethod + def __init_subclass__(cls, /, game: str) -> None: + if game != "Archipelago": + custom_rules = CustomRuleRegister.custom_rules.setdefault(game, {}) + if cls.__qualname__ in custom_rules: + raise TypeError(f"Rule {cls.__qualname__} has already been registered for game {game}") + custom_rules[cls.__qualname__] = cls + elif cls.__module__ != "rule_builder.rules": + raise TypeError("You cannot define custom rules for the base Archipelago world") + cls.game_name = game + + class Resolved(metaclass=CustomRuleRegister): + """A resolved rule for a given world that can be used as an access rule""" + + _: dataclasses.KW_ONLY + + player: int + """The player this rule is for""" + + caching_enabled: bool = dataclasses.field(repr=False, default=False, kw_only=True) + """If the world this rule is for has caching enabled""" + + force_recalculate: ClassVar[bool] = False + """Forces this rule to be recalculated every time it is evaluated. + Forces any parent composite rules containing this rule to also be recalculated. Implies skip_cache.""" + + skip_cache: ClassVar[bool] = False + """Skips the caching layer when evaluating this rule. + Composite rules will still respect the caching layer so dependencies functions should be implemented as normal. + Set to True when rule calculation is trivial.""" + + always_true: ClassVar[bool] = False + """Whether this rule always evaluates to True, used to short-circuit logic""" + + always_false: ClassVar[bool] = False + """Whether this rule always evaluates to True, used to short-circuit logic""" + + def __post_init__(self) -> None: + object.__setattr__( + self, + "caching_enabled", + self.caching_enabled and not self.force_recalculate and not self.skip_cache, + ) + + def __call__(self, state: CollectionState) -> bool: + """Evaluate this rule's result with the given state, using the cached value if possible""" + if not self.caching_enabled: + return self._evaluate(state) + + player_results = cast(dict[int, bool], state.rule_builder_cache[self.player]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + cached_result = player_results.get(id(self)) + if cached_result is not None: + return cached_result + + result = self._evaluate(state) + player_results[id(self)] = result + return result + + def _evaluate(self, state: CollectionState) -> bool: + """Calculate this rule's result with the given state""" + ... + + def item_dependencies(self) -> dict[str, set[int]]: + """Returns a mapping of item name to set of object ids, used for cache invalidation""" + return {} + + def region_dependencies(self) -> dict[str, set[int]]: + """Returns a mapping of region name to set of object ids, + used for indirect connections and cache invalidation""" + return {} + + def location_dependencies(self) -> dict[str, set[int]]: + """Returns a mapping of location name to set of object ids, used for cache invalidation""" + return {} + + def entrance_dependencies(self) -> dict[str, set[int]]: + """Returns a mapping of entrance name to set of object ids, used for cache invalidation""" + return {} + + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + """Returns a list of printJSON messages that explain the logic for this rule""" + return [{"type": "text", "text": self.rule_name}] + + def explain_str(self, state: CollectionState | None = None) -> str: + """Returns a human readable string describing this rule""" + return str(self) + + @override + def __str__(self) -> str: + return self.rule_name + + +@dataclasses.dataclass() +class True_(Rule[TWorld], game="Archipelago"): # noqa: N801 + """A rule that always returns True""" + + class Resolved(Rule.Resolved): + always_true: ClassVar[bool] = True + skip_cache: ClassVar[bool] = True + + @override + def _evaluate(self, state: CollectionState) -> bool: + return True + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + return [{"type": "color", "color": "green", "text": "True"}] + + @override + def __str__(self) -> str: + return "True" + + +@dataclasses.dataclass() +class False_(Rule[TWorld], game="Archipelago"): # noqa: N801 + """A rule that always returns False""" + + class Resolved(Rule.Resolved): + always_false: ClassVar[bool] = True + skip_cache: ClassVar[bool] = True + + @override + def _evaluate(self, state: CollectionState) -> bool: + return False + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + return [{"type": "color", "color": "salmon", "text": "False"}] + + @override + def __str__(self) -> str: + return "False" + + +@dataclasses.dataclass(init=False) +class NestedRule(Rule[TWorld], game="Archipelago"): + """A base rule class that takes an iterable of other rules as an argument and does logic based on them""" + + children: tuple[Rule[TWorld], ...] + """The child rules this rule's logic is based on""" + + def __init__(self, *children: Rule[TWorld], options: Iterable[OptionFilter] = ()) -> None: + super().__init__(options=options) + self.children = children + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + children = [c.resolve(world) for c in self.children] + return self.Resolved( + tuple(children), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + del data["args"] + data["children"] = [c.to_dict() for c in self.children] + return data + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + children = [world_cls.rule_from_dict(c) for c in data.get("children", ())] + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(*children, options=options) + + @override + def __str__(self) -> str: + children = ", ".join(str(c) for c in self.children) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({children}{options})" + + class Resolved(Rule.Resolved): + children: tuple[Rule.Resolved, ...] + + def __post_init__(self) -> None: + object.__setattr__( + self, + "force_recalculate", + self.force_recalculate or any(c.force_recalculate for c in self.children), + ) + super().__post_init__() + + @override + def item_dependencies(self) -> dict[str, set[int]]: + combined_deps: dict[str, set[int]] = {} + for child in self.children: + for item_name, rules in child.item_dependencies().items(): + if item_name in combined_deps: + combined_deps[item_name] |= rules + else: + combined_deps[item_name] = {id(self), *rules} + return combined_deps + + @override + def region_dependencies(self) -> dict[str, set[int]]: + combined_deps: dict[str, set[int]] = {} + for child in self.children: + for region_name, rules in child.region_dependencies().items(): + if region_name in combined_deps: + combined_deps[region_name] |= rules + else: + combined_deps[region_name] = {id(self), *rules} + return combined_deps + + @override + def location_dependencies(self) -> dict[str, set[int]]: + combined_deps: dict[str, set[int]] = {} + for child in self.children: + for location_name, rules in child.location_dependencies().items(): + if location_name in combined_deps: + combined_deps[location_name] |= rules + else: + combined_deps[location_name] = {id(self), *rules} + return combined_deps + + @override + def entrance_dependencies(self) -> dict[str, set[int]]: + combined_deps: dict[str, set[int]] = {} + for child in self.children: + for entrance_name, rules in child.entrance_dependencies().items(): + if entrance_name in combined_deps: + combined_deps[entrance_name] |= rules + else: + combined_deps[entrance_name] = {id(self), *rules} + return combined_deps + + +@dataclasses.dataclass(init=False) +class And(NestedRule[TWorld], game="Archipelago"): + """A rule that only returns true when all child rules evaluate as true""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + children_to_process = [c.resolve(world) for c in self.children] + clauses: list[Rule.Resolved] = [] + items: dict[str, int] = {} + true_rule: Rule.Resolved | None = None + + while children_to_process: + child = children_to_process.pop(0) + if child.always_false: + # false always wins + return child + if child.always_true: + # dedupe trues + true_rule = child + continue + if isinstance(child, And.Resolved): + children_to_process.extend(child.children) + continue + + if isinstance(child, Has.Resolved): + if child.item_name not in items or items[child.item_name] < child.count: + items[child.item_name] = child.count + elif isinstance(child, HasAll.Resolved): + for item in child.item_names: + if item not in items: + items[item] = 1 + elif isinstance(child, HasAllCounts.Resolved): + for item, count in child.item_counts: + if item not in items or items[item] < count: + items[item] = count + else: + clauses.append(child) + + if not clauses and not items: + return true_rule or False_().resolve(world) + + if len(items) == 1: + item, count = next(iter(items.items())) + clauses.append(Has(item, count).resolve(world)) + elif items and all(count == 1 for count in items.values()): + clauses.append(HasAll(*items).resolve(world)) + elif items: + clauses.append(HasAllCounts(items).resolve(world)) + + if len(clauses) == 1: + return clauses[0] + + return And.Resolved( + tuple(clauses), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + class Resolved(NestedRule.Resolved): + @override + def _evaluate(self, state: CollectionState) -> bool: + for rule in self.children: + if not rule(state): + return False + return True + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [{"type": "text", "text": "("}] + for i, child in enumerate(self.children): + if i > 0: + messages.append({"type": "text", "text": " & "}) + messages.extend(child.explain_json(state)) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + clauses = " & ".join([c.explain_str(state) for c in self.children]) + return f"({clauses})" + + @override + def __str__(self) -> str: + clauses = " & ".join([str(c) for c in self.children]) + return f"({clauses})" + + +@dataclasses.dataclass(init=False) +class Or(NestedRule[TWorld], game="Archipelago"): + """A rule that returns true when any child rule evaluates as true""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + children_to_process = [c.resolve(world) for c in self.children] + clauses: list[Rule.Resolved] = [] + items: dict[str, int] = {} + + while children_to_process: + child = children_to_process.pop(0) + if child.always_true: + # true always wins + return child + if child.always_false: + # falses can be ignored + continue + if isinstance(child, Or.Resolved): + children_to_process.extend(child.children) + continue + + if isinstance(child, Has.Resolved): + if child.item_name not in items or child.count < items[child.item_name]: + items[child.item_name] = child.count + elif isinstance(child, HasAny.Resolved): + for item in child.item_names: + items[item] = 1 + elif isinstance(child, HasAnyCount.Resolved): + for item, count in child.item_counts: + if item not in items or items[item] < count: + items[item] = count + else: + clauses.append(child) + + if not clauses and not items: + return False_().resolve(world) + + if len(items) == 1: + item, count = next(iter(items.items())) + clauses.append(Has(item, count).resolve(world)) + elif items and all(count == 1 for count in items.values()): + clauses.append(HasAny(*items).resolve(world)) + elif items: + clauses.append(HasAnyCount(items).resolve(world)) + + if len(clauses) == 1: + return clauses[0] + + return Or.Resolved( + tuple(clauses), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + class Resolved(NestedRule.Resolved): + @override + def _evaluate(self, state: CollectionState) -> bool: + for rule in self.children: + if rule(state): + return True + return False + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [{"type": "text", "text": "("}] + for i, child in enumerate(self.children): + if i > 0: + messages.append({"type": "text", "text": " | "}) + messages.extend(child.explain_json(state)) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + clauses = " | ".join([c.explain_str(state) for c in self.children]) + return f"({clauses})" + + @override + def __str__(self) -> str: + clauses = " | ".join([str(c) for c in self.children]) + return f"({clauses})" + + +@dataclasses.dataclass() +class WrapperRule(Rule[TWorld], game="Archipelago"): + """A base rule class that wraps another rule to provide extra logic or data""" + + child: Rule[TWorld] + """The child rule being wrapped""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + return self.Resolved( + self.child.resolve(world), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + del data["args"] + data["child"] = self.child.to_dict() + return data + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + child = data.get("child") + if child is None: + raise ValueError("Child rule cannot be None") + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(world_cls.rule_from_dict(child), options=options) + + @override + def __str__(self) -> str: + return f"{self.__class__.__name__}[{self.child}]" + + class Resolved(Rule.Resolved): + child: Rule.Resolved + + def __post_init__(self) -> None: + object.__setattr__(self, "force_recalculate", self.force_recalculate or self.child.force_recalculate) + super().__post_init__() + + @override + def _evaluate(self, state: CollectionState) -> bool: + return self.child(state) + + @override + def item_dependencies(self) -> dict[str, set[int]]: + deps: dict[str, set[int]] = {} + for item_name, rules in self.child.item_dependencies().items(): + deps[item_name] = {id(self), *rules} + return deps + + @override + def region_dependencies(self) -> dict[str, set[int]]: + deps: dict[str, set[int]] = {} + for region_name, rules in self.child.region_dependencies().items(): + deps[region_name] = {id(self), *rules} + return deps + + @override + def location_dependencies(self) -> dict[str, set[int]]: + deps: dict[str, set[int]] = {} + for location_name, rules in self.child.location_dependencies().items(): + deps[location_name] = {id(self), *rules} + return deps + + @override + def entrance_dependencies(self) -> dict[str, set[int]]: + deps: dict[str, set[int]] = {} + for entrance_name, rules in self.child.entrance_dependencies().items(): + deps[entrance_name] = {id(self), *rules} + return deps + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [{"type": "text", "text": f"{self.rule_name} ["}] + messages.extend(self.child.explain_json(state)) + messages.append({"type": "text", "text": "]"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + return f"{self.rule_name}[{self.child.explain_str(state)}]" + + @override + def __str__(self) -> str: + return f"{self.rule_name}[{self.child}]" + + +@dataclasses.dataclass() +class Filtered(WrapperRule[TWorld], game="Archipelago"): + """A convenience rule to wrap an existing rule with an options filter""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + return self.child.resolve(world) + + +@dataclasses.dataclass() +class Has(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least `count` of a given item""" + + item_name: str + """The item to check for""" + + count: int = 1 + """The count the player is required to have""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + return self.Resolved( + self.item_name, + self.count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + count = f", count={self.count}" if self.count > 1 else "" + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.item_name}{count}{options})" + + class Resolved(Rule.Resolved): + item_name: str + count: int = 1 + skip_cache: ClassVar[bool] = True + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has + return state.prog_items[self.player][self.item_name] >= self.count + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {self.item_name: set()} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + verb = "Missing " if state and not self(state) else "Has " + messages: list[JSONMessagePart] = [{"type": "text", "text": verb}] + if self.count > 1: + messages.append({"type": "color", "color": "cyan", "text": str(self.count)}) + messages.append({"type": "text", "text": "x "}) + if state: + color = "green" if self(state) else "salmon" + messages.append({"type": "color", "color": color, "text": self.item_name}) + else: + messages.append({"type": "item_name", "flags": 0b001, "text": self.item_name, "player": self.player}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + prefix = "Has" if self(state) else "Missing" + count = f"{self.count}x " if self.count > 1 else "" + return f"{prefix} {count}{self.item_name}" + + @override + def __str__(self) -> str: + count = f"{self.count}x " if self.count > 1 else "" + return f"Has {count}{self.item_name}" + + +@dataclasses.dataclass(init=False) +class HasAll(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has all of the given items""" + + item_names: tuple[str, ...] + """A tuple of item names to check for""" + + def __init__(self, *item_names: str, options: Iterable[OptionFilter] = ()) -> None: + super().__init__(options=options) + self.item_names = tuple(sorted(set(item_names))) + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_names) == 0: + # match state.has_all + return True_().resolve(world) + if len(self.item_names) == 1: + return Has(self.item_names[0]).resolve(world) + return self.Resolved( + self.item_names, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = {**data.get("args", {})} + item_names = args.pop("item_names", ()) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(*item_names, **args, options=options) + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}{options})" + + class Resolved(Rule.Resolved): + item_names: tuple[str, ...] + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_all + player_prog_items = state.prog_items[self.player] + for item in self.item_names: + if not player_prog_items[item]: + return False + return True + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + ] + for i, item in enumerate(self.item_names): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": ")"}) + return messages + + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + messages = [ + {"type": "text", "text": "Has " if not missing else "Missing "}, + {"type": "color", "color": "cyan", "text": "all" if not missing else "some"}, + {"type": "text", "text": " of ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, item in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, item in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + prefix = "Has all" if self(state) else "Missing some" + found_str = f"Found: {', '.join(found)}" if found else "" + missing_str = f"Missing: {', '.join(missing)}" if missing else "" + infix = "; " if found and missing else "" + return f"{prefix} of ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + return f"Has all of ({items})" + + +@dataclasses.dataclass(init=False) +class HasAny(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least one of the given items""" + + item_names: tuple[str, ...] + """A tuple of item names to check for""" + + def __init__(self, *item_names: str, options: Iterable[OptionFilter] = ()) -> None: + super().__init__(options=options) + self.item_names = tuple(sorted(set(item_names))) + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_names) == 0: + # match state.has_any + return False_().resolve(world) + if len(self.item_names) == 1: + return Has(self.item_names[0]).resolve(world) + return self.Resolved( + self.item_names, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = {**data.get("args", {})} + item_names = args.pop("item_names", ()) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(*item_names, **args, options=options) + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}{options})" + + class Resolved(Rule.Resolved): + item_names: tuple[str, ...] + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_any + player_prog_items = state.prog_items[self.player] + for item in self.item_names: + if player_prog_items[item]: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "any"}, + {"type": "text", "text": " of ("}, + ] + for i, item in enumerate(self.item_names): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": ")"}) + return messages + + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + messages = [ + {"type": "text", "text": "Has " if found else "Missing "}, + {"type": "color", "color": "cyan", "text": "some" if found else "all"}, + {"type": "text", "text": " of ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, item in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, item in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + prefix = "Has some" if self(state) else "Missing all" + found_str = f"Found: {', '.join(found)}" if found else "" + missing_str = f"Missing: {', '.join(missing)}" if missing else "" + infix = "; " if found and missing else "" + return f"{prefix} of ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + return f"Has any of ({items})" + + +@dataclasses.dataclass() +class HasAllCounts(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has all of the specified counts of the given items""" + + item_counts: dict[str, int] + """A mapping of item name to count to check for""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_counts) == 0: + # match state.has_all_counts + return True_().resolve(world) + if len(self.item_counts) == 1: + item = next(iter(self.item_counts)) + return Has(item, self.item_counts[item]).resolve(world) + return self.Resolved( + tuple(self.item_counts.items()), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()]) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}{options})" + + class Resolved(Rule.Resolved): + item_counts: tuple[tuple[str, int], ...] + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_all_counts + player_prog_items = state.prog_items[self.player] + for item, count in self.item_counts: + if player_prog_items[item] < count: + return False + return True + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item, _ in self.item_counts} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + ] + for i, (item, count) in enumerate(self.item_counts): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": f" x{count}"}) + messages.append({"type": "text", "text": ")"}) + return messages + + found = [(item, count) for item, count in self.item_counts if state.has(item, self.player, count)] + missing = [(item, count) for item, count in self.item_counts if (item, count) not in found] + messages = [ + {"type": "text", "text": "Has " if not missing else "Missing "}, + {"type": "color", "color": "cyan", "text": "all" if not missing else "some"}, + {"type": "text", "text": " of ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, (item, count) in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + messages.append({"type": "text", "text": f" x{count}"}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, (item, count) in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": f" x{count}"}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found = [(item, count) for item, count in self.item_counts if state.has(item, self.player, count)] + missing = [(item, count) for item, count in self.item_counts if (item, count) not in found] + prefix = "Has all" if self(state) else "Missing some" + found_str = f"Found: {', '.join([f'{item} x{count}' for item, count in found])}" if found else "" + missing_str = f"Missing: {', '.join([f'{item} x{count}' for item, count in missing])}" if missing else "" + infix = "; " if found and missing else "" + return f"{prefix} of ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join([f"{item} x{count}" for item, count in self.item_counts]) + return f"Has all of ({items})" + + +@dataclasses.dataclass() +class HasAnyCount(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has any of the specified counts of the given items""" + + item_counts: dict[str, int] + """A mapping of item name to count to check for""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_counts) == 0: + # match state.has_any_count + return False_().resolve(world) + if len(self.item_counts) == 1: + item = next(iter(self.item_counts)) + return Has(item, self.item_counts[item]).resolve(world) + return self.Resolved( + tuple(self.item_counts.items()), + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()]) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}{options})" + + class Resolved(Rule.Resolved): + item_counts: tuple[tuple[str, int], ...] + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_any_count + player_prog_items = state.prog_items[self.player] + for item, count in self.item_counts: + if player_prog_items[item] >= count: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item, _ in self.item_counts} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "any"}, + {"type": "text", "text": " of ("}, + ] + for i, (item, count) in enumerate(self.item_counts): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": f" x{count}"}) + messages.append({"type": "text", "text": ")"}) + return messages + + found = [(item, count) for item, count in self.item_counts if state.has(item, self.player, count)] + missing = [(item, count) for item, count in self.item_counts if (item, count) not in found] + messages = [ + {"type": "text", "text": "Has " if found else "Missing "}, + {"type": "color", "color": "cyan", "text": "some" if found else "all"}, + {"type": "text", "text": " of ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, (item, count) in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + messages.append({"type": "text", "text": f" x{count}"}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, (item, count) in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": f" x{count}"}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found = [(item, count) for item, count in self.item_counts if state.has(item, self.player, count)] + missing = [(item, count) for item, count in self.item_counts if (item, count) not in found] + prefix = "Has some" if self(state) else "Missing all" + found_str = f"Found: {', '.join([f'{item} x{count}' for item, count in found])}" if found else "" + missing_str = f"Missing: {', '.join([f'{item} x{count}' for item, count in missing])}" if missing else "" + infix = "; " if found and missing else "" + return f"{prefix} of ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join([f"{item} x{count}" for item, count in self.item_counts]) + return f"Has any of ({items})" + + +@dataclasses.dataclass(init=False) +class HasFromList(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least `count` of the given items""" + + item_names: tuple[str, ...] + """A tuple of item names to check for""" + + count: int = 1 + """The number of items the player needs to have""" + + def __init__(self, *item_names: str, count: int = 1, options: Iterable[OptionFilter] = ()) -> None: + super().__init__(options=options) + self.item_names = tuple(sorted(set(item_names))) + self.count = count + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_names) == 0: + # match state.has_from_list + return False_().resolve(world) + if len(self.item_names) == 1: + return Has(self.item_names[0], self.count).resolve(world) + return self.Resolved( + self.item_names, + self.count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = {**data.get("args", {})} + item_names = args.pop("item_names", ()) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(*item_names, **args, options=options) + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}, count={self.count}{options})" + + class Resolved(Rule.Resolved): + item_names: tuple[str, ...] + count: int = 1 + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_from_list + found = 0 + player_prog_items = state.prog_items[self.player] + for item_name in self.item_names: + found += player_prog_items[item_name] + if found >= self.count: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": str(self.count)}, + {"type": "text", "text": "x items from ("}, + ] + for i, item in enumerate(self.item_names): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": ")"}) + return messages + + found_count = state.count_from_list(self.item_names, self.player) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + color = "green" if found_count >= self.count else "salmon" + messages = [ + {"type": "text", "text": "Has "}, + { + "type": "color", + "color": color, + "text": f"{found_count}/{self.count}", + }, + {"type": "text", "text": " items from ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, item in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, item in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found_count = state.count_from_list(self.item_names, self.player) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + found_str = f"Found: {', '.join(found)}" if found else "" + missing_str = f"Missing: {', '.join(missing)}" if missing else "" + infix = "; " if found and missing else "" + return f"Has {found_count}/{self.count} items from ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + count = f"{self.count}x items" if self.count > 1 else "an item" + return f"Has {count} from ({items})" + + +@dataclasses.dataclass(init=False) +class HasFromListUnique(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least `count` of the given items, ignoring duplicates of the same item""" + + item_names: tuple[str, ...] + """A tuple of item names to check for""" + + count: int = 1 + """The number of items the player needs to have""" + + def __init__(self, *item_names: str, count: int = 1, options: Iterable[OptionFilter] = ()) -> None: + super().__init__(options=options) + self.item_names = tuple(sorted(set(item_names))) + self.count = count + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + if len(self.item_names) == 0 or len(self.item_names) < self.count: + # match state.has_from_list_unique + return False_().resolve(world) + if len(self.item_names) == 1: + return Has(self.item_names[0]).resolve(world) + return self.Resolved( + self.item_names, + self.count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = {**data.get("args", {})} + item_names = args.pop("item_names", ()) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(*item_names, **args, options=options) + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({items}, count={self.count}{options})" + + class Resolved(Rule.Resolved): + item_names: tuple[str, ...] + count: int = 1 + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_from_list_unique + found = 0 + player_prog_items = state.prog_items[self.player] + for item_name in self.item_names: + found += player_prog_items[item_name] > 0 + if found >= self.count: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": str(self.count)}, + {"type": "text", "text": "x unique items from ("}, + ] + for i, item in enumerate(self.item_names): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "item_name", "flags": 0b001, "text": item, "player": self.player}) + messages.append({"type": "text", "text": ")"}) + return messages + + found_count = state.count_from_list_unique(self.item_names, self.player) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + color = "green" if found_count >= self.count else "salmon" + messages = [ + {"type": "text", "text": "Has "}, + {"type": "color", "color": color, "text": f"{found_count}/{self.count}"}, + {"type": "text", "text": " unique items from ("}, + ] + if found: + messages.append({"type": "text", "text": "Found: "}) + for i, item in enumerate(found): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "green", "text": item}) + if missing: + messages.append({"type": "text", "text": "; "}) + + if missing: + messages.append({"type": "text", "text": "Missing: "}) + for i, item in enumerate(missing): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.append({"type": "color", "color": "salmon", "text": item}) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + found_count = state.count_from_list_unique(self.item_names, self.player) + found = [item for item in self.item_names if state.has(item, self.player)] + missing = [item for item in self.item_names if item not in found] + found_str = f"Found: {', '.join(found)}" if found else "" + missing_str = f"Missing: {', '.join(missing)}" if missing else "" + infix = "; " if found and missing else "" + return f"Has {found_count}/{self.count} unique items from ({found_str}{infix}{missing_str})" + + @override + def __str__(self) -> str: + items = ", ".join(self.item_names) + count = f"{self.count}x unique items" if self.count > 1 else "a unique item" + return f"Has {count} from ({items})" + + +@dataclasses.dataclass() +class HasGroup(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least `count` of the items present in the specified item group""" + + item_name_group: str + """The name of the item group containing the items""" + + count: int = 1 + """The number of items the player needs to have""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + item_names = tuple(sorted(world.item_name_groups[self.item_name_group])) + return self.Resolved( + self.item_name_group, + item_names, + self.count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + count = f", count={self.count}" if self.count > 1 else "" + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.item_name_group}{count}{options})" + + class Resolved(Rule.Resolved): + item_name_group: str + item_names: tuple[str, ...] + count: int = 1 + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_group + found = 0 + player_prog_items = state.prog_items[self.player] + for item_name in self.item_names: + found += player_prog_items[item_name] + if found >= self.count: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [{"type": "text", "text": "Has "}] + if state is None: + messages.append({"type": "color", "color": "cyan", "text": str(self.count)}) + else: + count = state.count_group(self.item_name_group, self.player) + color = "green" if count >= self.count else "salmon" + messages.append({"type": "color", "color": color, "text": f"{count}/{self.count}"}) + messages.append({"type": "text", "text": " items from "}) + messages.append({"type": "color", "color": "cyan", "text": self.item_name_group}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + count = state.count_group(self.item_name_group, self.player) + return f"Has {count}/{self.count} items from {self.item_name_group}" + + @override + def __str__(self) -> str: + count = f"{self.count}x items" if self.count > 1 else "an item" + return f"Has {count} from {self.item_name_group}" + + +@dataclasses.dataclass() +class HasGroupUnique(Rule[TWorld], game="Archipelago"): + """A rule that checks if the player has at least `count` of the items present + in the specified item group, ignoring duplicates of the same item""" + + item_name_group: str + """The name of the item group containing the items""" + + count: int = 1 + """The number of items the player needs to have""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + item_names = tuple(sorted(world.item_name_groups[self.item_name_group])) + return self.Resolved( + self.item_name_group, + item_names, + self.count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + count = f", count={self.count}" if self.count > 1 else "" + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.item_name_group}{count}{options})" + + class Resolved(Rule.Resolved): + item_name_group: str + item_names: tuple[str, ...] + count: int = 1 + + @override + def _evaluate(self, state: CollectionState) -> bool: + # implementation based on state.has_group_unique + found = 0 + player_prog_items = state.prog_items[self.player] + for item_name in self.item_names: + found += player_prog_items[item_name] > 0 + if found >= self.count: + return True + return False + + @override + def item_dependencies(self) -> dict[str, set[int]]: + return {item: {id(self)} for item in self.item_names} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [{"type": "text", "text": "Has "}] + if state is None: + messages.append({"type": "color", "color": "cyan", "text": str(self.count)}) + else: + count = state.count_group_unique(self.item_name_group, self.player) + color = "green" if count >= self.count else "salmon" + messages.append({"type": "color", "color": color, "text": f"{count}/{self.count}"}) + messages.append({"type": "text", "text": " unique items from "}) + messages.append({"type": "color", "color": "cyan", "text": self.item_name_group}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + count = state.count_group_unique(self.item_name_group, self.player) + return f"Has {count}/{self.count} unique items from {self.item_name_group}" + + @override + def __str__(self) -> str: + count = f"{self.count}x unique items" if self.count > 1 else "a unique item" + return f"Has {count} from {self.item_name_group}" + + +@dataclasses.dataclass() +class CanReachLocation(Rule[TWorld], game="Archipelago"): + """A rule that checks if the given location is reachable by the current player""" + + location_name: str + """The name of the location to test access to""" + + parent_region_name: str = "" + """The name of the location's parent region. If not specified it will be resolved when the rule is resolved""" + + skip_indirect_connection: bool = False + """Skip finding the location's parent region. + Do not use this if this rule is for an entrance and explicit_indirect_conditions is True + """ + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + parent_region_name = self.parent_region_name + if not parent_region_name and not self.skip_indirect_connection: + location = world.get_location(self.location_name) + if not location.parent_region: + raise ValueError(f"Location {location.name} has no parent region") + parent_region_name = location.parent_region.name + return self.Resolved( + self.location_name, + parent_region_name, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.location_name}{options})" + + class Resolved(Rule.Resolved): + location_name: str + parent_region_name: str + + @override + def _evaluate(self, state: CollectionState) -> bool: + return state.can_reach_location(self.location_name, self.player) + + @override + def region_dependencies(self) -> dict[str, set[int]]: + if self.parent_region_name: + return {self.parent_region_name: {id(self)}} + return {} + + @override + def location_dependencies(self) -> dict[str, set[int]]: + return {self.location_name: {id(self)}} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + if state is None: + verb = "Can reach" + elif self(state): + verb = "Reached" + else: + verb = "Cannot reach" + return [ + {"type": "text", "text": f"{verb} location "}, + {"type": "location_name", "text": self.location_name, "player": self.player}, + ] + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + prefix = "Reached" if self(state) else "Cannot reach" + return f"{prefix} location {self.location_name}" + + @override + def __str__(self) -> str: + return f"Can reach location {self.location_name}" + + +@dataclasses.dataclass() +class CanReachRegion(Rule[TWorld], game="Archipelago"): + """A rule that checks if the given region is reachable by the current player""" + + region_name: str + """The name of the region to test access to""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + return self.Resolved( + self.region_name, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.region_name}{options})" + + class Resolved(Rule.Resolved): + region_name: str + + @override + def _evaluate(self, state: CollectionState) -> bool: + return state.can_reach_region(self.region_name, self.player) + + @override + def region_dependencies(self) -> dict[str, set[int]]: + return {self.region_name: {id(self)}} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + if state is None: + verb = "Can reach" + elif self(state): + verb = "Reached" + else: + verb = "Cannot reach" + return [ + {"type": "text", "text": f"{verb} region "}, + {"type": "color", "color": "yellow", "text": self.region_name}, + ] + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + prefix = "Reached" if self(state) else "Cannot reach" + return f"{prefix} region {self.region_name}" + + @override + def __str__(self) -> str: + return f"Can reach region {self.region_name}" + + +@dataclasses.dataclass() +class CanReachEntrance(Rule[TWorld], game="Archipelago"): + """A rule that checks if the given entrance is reachable by the current player""" + + entrance_name: str + """The name of the entrance to test access to""" + + parent_region_name: str = "" + """The name of the entrance's parent region. If not specified it will be resolved when the rule is resolved""" + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + parent_region_name = self.parent_region_name + if not parent_region_name: + entrance = world.get_entrance(self.entrance_name) + if not entrance.parent_region: + raise ValueError(f"Entrance {entrance.name} has no parent region") + parent_region_name = entrance.parent_region.name + return self.Resolved( + self.entrance_name, + parent_region_name, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def __str__(self) -> str: + options = f", options={self.options}" if self.options else "" + return f"{self.__class__.__name__}({self.entrance_name}{options})" + + class Resolved(Rule.Resolved): + entrance_name: str + parent_region_name: str + + @override + def _evaluate(self, state: CollectionState) -> bool: + return state.can_reach_entrance(self.entrance_name, self.player) + + @override + def region_dependencies(self) -> dict[str, set[int]]: + if self.parent_region_name: + return {self.parent_region_name: {id(self)}} + return {} + + @override + def entrance_dependencies(self) -> dict[str, set[int]]: + return {self.entrance_name: {id(self)}} + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + if state is None: + verb = "Can reach" + elif self(state): + verb = "Reached" + else: + verb = "Cannot reach" + return [ + {"type": "text", "text": f"{verb} entrance "}, + {"type": "entrance_name", "text": self.entrance_name, "player": self.player}, + ] + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + if state is None: + return str(self) + prefix = "Reached" if self(state) else "Cannot reach" + return f"{prefix} entrance {self.entrance_name}" + + @override + def __str__(self) -> str: + return f"Can reach entrance {self.entrance_name}" + + +DEFAULT_RULES: "Final[dict[str, type[Rule[World]]]]" = { + rule_name: cast("type[Rule[World]]", rule_class) + for rule_name, rule_class in locals().items() + if isinstance(rule_class, type) and issubclass(rule_class, Rule) and rule_class is not Rule +} diff --git a/test/general/test_rule_builder.py b/test/general/test_rule_builder.py new file mode 100644 index 0000000000..52248b6047 --- /dev/null +++ b/test/general/test_rule_builder.py @@ -0,0 +1,1336 @@ +import unittest +from dataclasses import dataclass, fields +from typing import Any, ClassVar, cast + +from typing_extensions import override + +from BaseClasses import CollectionState, Item, ItemClassification, Location, MultiWorld, Region +from NetUtils import JSONMessagePart +from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Toggle +from rule_builder.cached_world import CachedRuleBuilderWorld +from rule_builder.options import Operator, OptionFilter +from rule_builder.rules import ( + And, + CanReachEntrance, + CanReachLocation, + CanReachRegion, + False_, + Filtered, + Has, + HasAll, + HasAllCounts, + HasAny, + HasAnyCount, + HasFromList, + HasFromListUnique, + HasGroup, + HasGroupUnique, + Or, + Rule, + True_, +) +from test.general import setup_solo_multiworld +from test.param import classvar_matrix +from worlds.AutoWorld import AutoWorldRegister, World + + +class CachedCollectionState(CollectionState): + rule_builder_cache: dict[int, dict[int, bool]] # pyright: ignore[reportUninitializedInstanceVariable] + + +class ToggleOption(Toggle): + auto_display_name = True + + +class ChoiceOption(Choice): + auto_display_name = True + option_first = 0 + option_second = 1 + option_third = 2 + default = 0 + + +class FreeTextOption(FreeText): + auto_display_name = True + + +class SetOption(OptionSet): + auto_display_name = True + valid_keys: ClassVar[set[str]] = {"one", "two", "three"} # pyright: ignore[reportIncompatibleVariableOverride] + + +@dataclass +class RuleBuilderOptions(PerGameCommonOptions): + toggle_option: ToggleOption + choice_option: ChoiceOption + text_option: FreeTextOption + set_option: SetOption + + +GAME_NAME = "Rule Builder Test Game" +LOC_COUNT = 20 + + +class RuleBuilderItem(Item): + game = GAME_NAME + + +class RuleBuilderLocation(Location): + game = GAME_NAME + + +class RuleBuilderTestCase(unittest.TestCase): + old_world_types: dict[str, type[World]] # pyright: ignore[reportUninitializedInstanceVariable] + world_cls: type[World] # pyright: ignore[reportUninitializedInstanceVariable] + + @override + def setUp(self) -> None: + self.old_world_types = AutoWorldRegister.world_types.copy() + self._create_world_class() + + @override + def tearDown(self) -> None: + AutoWorldRegister.world_types = self.old_world_types + assert GAME_NAME not in AutoWorldRegister.world_types + + def _create_world_class(self) -> None: + class RuleBuilderWorld(World): + game = GAME_NAME + item_name_to_id: ClassVar = {f"Item {i}": i for i in range(1, LOC_COUNT + 1)} + location_name_to_id: ClassVar = {f"Location {i}": i for i in range(1, LOC_COUNT + 1)} + item_name_groups: ClassVar = { + "Group 1": {"Item 1", "Item 2", "Item 3"}, + "Group 2": {"Item 4", "Item 5"}, + } + hidden = True + options_dataclass = RuleBuilderOptions + options: RuleBuilderOptions # pyright: ignore[reportIncompatibleVariableOverride] + origin_region_name = "Region 1" + + @override + def create_item(self, name: str) -> RuleBuilderItem: + classification = ItemClassification.filler if name == "Filler" else ItemClassification.progression + return RuleBuilderItem(name, classification, self.item_name_to_id[name], self.player) + + @override + def get_filler_item_name(self) -> str: + return "Filler" + + self.world_cls = RuleBuilderWorld + + +class CachedRuleBuilderTestCase(RuleBuilderTestCase): + @override + def _create_world_class(self) -> None: + class RuleBuilderWorld(CachedRuleBuilderWorld): + game = GAME_NAME + item_name_to_id: ClassVar = {f"Item {i}": i for i in range(1, LOC_COUNT + 1)} + location_name_to_id: ClassVar = {f"Location {i}": i for i in range(1, LOC_COUNT + 1)} + item_name_groups: ClassVar = { + "Group 1": {"Item 1", "Item 2", "Item 3"}, + "Group 2": {"Item 4", "Item 5"}, + } + hidden = True + options_dataclass = RuleBuilderOptions + options: RuleBuilderOptions # pyright: ignore[reportIncompatibleVariableOverride] + origin_region_name = "Region 1" + + @override + def create_item(self, name: str) -> RuleBuilderItem: + classification = ItemClassification.filler if name == "Filler" else ItemClassification.progression + return RuleBuilderItem(name, classification, self.item_name_to_id[name], self.player) + + @override + def get_filler_item_name(self) -> str: + return "Filler" + + self.world_cls = RuleBuilderWorld + + +@classvar_matrix( + rules=( + ( + And(Has("A", 1), Has("A", 2)), + Has.Resolved("A", 2, player=1), + ), + ( + And(Has("A"), HasAll("B", "C")), + HasAll.Resolved(("A", "B", "C"), player=1), + ), + ( + Or(Has("A", 1), Has("A", 2)), + Has.Resolved("A", 1, player=1), + ), + ( + Or(Has("A"), HasAny("B", "C")), + HasAny.Resolved(("A", "B", "C"), player=1), + ), + ( + Or(HasAll("A"), HasAll("A", "A")), + Has.Resolved("A", player=1), + ), + ( + Or( + Has("A"), + Or( + True_(options=[OptionFilter(ChoiceOption, 0)]), + HasAny("B", "C", options=[OptionFilter(ChoiceOption, 0, "gt")]), + options=[OptionFilter(ToggleOption, 1)], + ), + And(Has("D"), Has("E"), options=[OptionFilter(ToggleOption, 0)]), + Has("F"), + ), + Or.Resolved( + ( + HasAll.Resolved(("D", "E"), player=1), + HasAny.Resolved(("A", "F"), player=1), + ), + player=1, + ), + ), + ( + Or( + Has("A"), + Or( + True_(options=[OptionFilter(ChoiceOption, 0, "gt")]), + HasAny("B", "C", options=[OptionFilter(ChoiceOption, 0)]), + options=[OptionFilter(ToggleOption, 0)], + ), + And(Has("D"), Has("E"), options=[OptionFilter(ToggleOption, 1)]), + Has("F"), + ), + HasAny.Resolved(("A", "B", "C", "F"), player=1), + ), + ( + And(Has("A"), True_()), + Has.Resolved("A", player=1), + ), + ( + And(Has("A"), False_()), + False_.Resolved(player=1), + ), + ( + Or(Has("A"), True_()), + True_.Resolved(player=1), + ), + ( + Or(Has("A"), False_()), + Has.Resolved("A", player=1), + ), + ( + And(Has("A"), HasAll("B", "C"), HasAllCounts({"D": 2, "E": 3})), + HasAllCounts.Resolved((("A", 1), ("B", 1), ("C", 1), ("D", 2), ("E", 3)), player=1), + ), + ( + And(Has("A"), HasAll("B", "C"), HasAllCounts({"D": 1, "E": 1})), + HasAll.Resolved(("A", "B", "C", "D", "E"), player=1), + ), + ( + Or(Has("A"), HasAny("B", "C"), HasAnyCount({"D": 2, "E": 3})), + HasAnyCount.Resolved((("A", 1), ("B", 1), ("C", 1), ("D", 2), ("E", 3)), player=1), + ), + ( + Or(Has("A"), HasAny("B", "C"), HasAnyCount({"D": 1, "E": 1})), + HasAny.Resolved(("A", "B", "C", "D", "E"), player=1), + ), + ) +) +class TestSimplify(RuleBuilderTestCase): + rules: ClassVar[tuple[Rule[Any], Rule.Resolved]] + + def test_simplify(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0) + world = multiworld.worlds[1] + rule, expected = self.rules + resolved_rule = rule.resolve(world) + self.assertEqual(resolved_rule, expected, f"\n{resolved_rule}\n{expected}") + + +@classvar_matrix( + cases=( + (ToggleOption, 0, "eq", 0, True), + (ToggleOption, 0, "eq", 1, False), + (ToggleOption, 0, "ne", 0, False), + (ToggleOption, 0, "ne", 1, True), + (ChoiceOption, 0, "gt", 1, False), + (ChoiceOption, 1, "gt", 1, False), + (ChoiceOption, 2, "gt", 1, True), + (ChoiceOption, 0, "ge", "second", False), + (ChoiceOption, 1, "ge", "second", True), + (ChoiceOption, 1, "in", (0, 1), True), + (ChoiceOption, 1, "in", ("first", "second"), True), + (FreeTextOption, "no", "eq", "yes", False), + (FreeTextOption, "yes", "eq", "yes", True), + (SetOption, ("one", "two"), "contains", "three", False), + (SetOption, ("one", "two"), "contains", "two", True), + ) +) +class TestOptions(RuleBuilderTestCase): + cases: ClassVar[tuple[type[Option[Any]], Any, Operator, Any, bool]] + + def test_option_resolution(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0) + world = multiworld.worlds[1] + option_cls, world_value, operator, filter_value, expected = self.cases + + for field in fields(world.options_dataclass): + if field.type is option_cls: + setattr(world.options, field.name, option_cls.from_any(world_value)) + break + + option_filter = OptionFilter(option_cls, filter_value, operator) + result = option_filter.check(world.options) + self.assertEqual(result, expected, f"Expected {result} for option={option_filter} with value={world_value}") + + +class TestFilteredResolution(RuleBuilderTestCase): + def test_filtered_resolution(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0) + world = multiworld.worlds[1] + + rule_and_false = Has("A") & Has("B", options=[OptionFilter(ToggleOption, 1)], filtered_resolution=False) + rule_and_true = Has("A") & Has("B", options=[OptionFilter(ToggleOption, 1)], filtered_resolution=True) + rule_or_false = Has("A") | Has("B", options=[OptionFilter(ToggleOption, 1)], filtered_resolution=False) + rule_or_true = Has("A") | Has("B", options=[OptionFilter(ToggleOption, 1)], filtered_resolution=True) + + # option fails check + world.options.toggle_option.value = 0 # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + self.assertEqual(rule_and_false.resolve(world), False_.Resolved(player=1)) + self.assertEqual(rule_and_true.resolve(world), Has.Resolved("A", player=1)) + self.assertEqual(rule_or_false.resolve(world), Has.Resolved("A", player=1)) + self.assertEqual(rule_or_true.resolve(world), True_.Resolved(player=1)) + + # option passes check + world.options.toggle_option.value = 1 # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + self.assertEqual(rule_and_false.resolve(world), HasAll.Resolved(("A", "B"), player=1)) + self.assertEqual(rule_and_true.resolve(world), HasAll.Resolved(("A", "B"), player=1)) + self.assertEqual(rule_or_false.resolve(world), HasAny.Resolved(("A", "B"), player=1)) + self.assertEqual(rule_or_true.resolve(world), HasAny.Resolved(("A", "B"), player=1)) + + +@classvar_matrix( + rules=( + ( + Has("A") & Has("B"), + And(Has("A"), Has("B")), + ), + ( + Has("A") | Has("B"), + Or(Has("A"), Has("B")), + ), + ( + And(Has("A")) & Has("B"), + And(Has("A"), Has("B")), + ), + ( + And(Has("A"), Has("B")) & And(Has("C")), + And(Has("A"), Has("B"), Has("C")), + ), + ( + And(Has("A"), Has("B")) | Or(Has("C"), Has("D")), + Or(And(Has("A"), Has("B")), Has("C"), Has("D")), + ), + ( + Or(Has("A")) | Or(Has("B"), options=[OptionFilter(ToggleOption, 1)]), + Or(Or(Has("A")), Or(Has("B"), options=[OptionFilter(ToggleOption, 1)])), + ), + ( + ( + And(Has("A"), options=[OptionFilter(ToggleOption, 1)]) + & And(Has("B"), options=[OptionFilter(ToggleOption, 1)]) + ), + And(Has("A"), Has("B"), options=[OptionFilter(ToggleOption, 1)]), + ), + ( + Has("A") & Has("B") & Has("C"), + And(Has("A"), Has("B"), Has("C")), + ), + ( + Has("A") & Has("B") | Has("C") & Has("D"), + Or(And(Has("A"), Has("B")), And(Has("C"), Has("D"))), + ), + ( + Has("A") | Or(Has("B"), options=[OptionFilter(ToggleOption, 1)]), + Or(Has("A"), Or(Has("B"), options=[OptionFilter(ToggleOption, 1)])), + ), + ( + Has("A") & And(Has("B"), options=[OptionFilter(ToggleOption, 1)]), + And(Has("A"), And(Has("B"), options=[OptionFilter(ToggleOption, 1)])), + ), + ( + (Has("A") | Has("B")) & [OptionFilter(ToggleOption, 1)], + Filtered(Or(Has("A"), Has("B")), options=[OptionFilter(ToggleOption, 1)]), + ), + ( + (Has("A") | Has("B")) | OptionFilter(ToggleOption, 1), + Or(Or(Has("A"), Has("B")), True_(options=[OptionFilter(ToggleOption, 1)])), + ), + ( + OptionFilter(ToggleOption, 1) & (Has("A") | Has("B")), + Filtered(Or(Has("A"), Has("B")), options=[OptionFilter(ToggleOption, 1)]), + ), + ( + [OptionFilter(ToggleOption, 1)] | (Has("A") | Has("B")), + Or(Or(Has("A"), Has("B")), True_(options=[OptionFilter(ToggleOption, 1)])), + ), + ) +) +class TestComposition(unittest.TestCase): + rules: ClassVar[tuple[Rule[Any], Rule[Any]]] + + def test_composition(self) -> None: + combined_rule, expected = self.rules + self.assertEqual(combined_rule, expected, str(combined_rule)) + + +class TestHashes(RuleBuilderTestCase): + def test_and_hash(self) -> None: + rule1 = And.Resolved((True_.Resolved(player=1),), player=1) + rule2 = And.Resolved((True_.Resolved(player=1),), player=1) + rule3 = Or.Resolved((True_.Resolved(player=1),), player=1) + + self.assertEqual(hash(rule1), hash(rule2)) + self.assertNotEqual(hash(rule1), hash(rule3)) + + def test_has_all_hash(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0) + world = multiworld.worlds[1] + rule1 = HasAll("1", "2") + rule2 = HasAll("2", "2", "2", "1") + self.assertEqual(hash(rule1.resolve(world)), hash(rule2.resolve(world))) + + +class TestCaching(CachedRuleBuilderTestCase): + multiworld: MultiWorld # pyright: ignore[reportUninitializedInstanceVariable] + world: World # pyright: ignore[reportUninitializedInstanceVariable] + state: CachedCollectionState # pyright: ignore[reportUninitializedInstanceVariable] + player: int = 1 + + @override + def setUp(self) -> None: + super().setUp() + + self.multiworld = setup_solo_multiworld(self.world_cls, seed=0) + world = self.multiworld.worlds[1] + self.world = world + self.state = cast(CachedCollectionState, self.multiworld.state) + + region1 = Region("Region 1", self.player, self.multiworld) + region2 = Region("Region 2", self.player, self.multiworld) + region3 = Region("Region 3", self.player, self.multiworld) + self.multiworld.regions.extend([region1, region2, region3]) + + region1.add_locations({"Location 1": 1, "Location 2": 2, "Location 6": 6}, RuleBuilderLocation) + region2.add_locations({"Location 3": 3, "Location 4": 4}, RuleBuilderLocation) + region3.add_locations({"Location 5": 5}, RuleBuilderLocation) + + world.create_entrance(region1, region2, Has("Item 1")) + world.create_entrance(region1, region3, HasAny("Item 3", "Item 4")) + world.set_rule(world.get_location("Location 2"), CanReachRegion("Region 2") & Has("Item 2")) + world.set_rule(world.get_location("Location 4"), HasAll("Item 2", "Item 3")) + world.set_rule(world.get_location("Location 5"), CanReachLocation("Location 4")) + world.set_rule(world.get_location("Location 6"), CanReachEntrance("Region 1 -> Region 2") & Has("Item 2")) + + for i in range(1, LOC_COUNT + 1): + self.multiworld.itempool.append(world.create_item(f"Item {i}")) + + world.register_rule_builder_dependencies() + + def test_item_cache_busting(self) -> None: + location = self.world.get_location("Location 4") + self.state.collect(self.world.create_item("Item 1")) # access to region 2 + self.state.collect(self.world.create_item("Item 2")) # item directly needed + self.assertFalse(location.can_reach(self.state)) # populates cache + self.assertFalse(self.state.rule_builder_cache[1][id(location.access_rule)]) + + self.state.collect(self.world.create_item("Item 3")) # clears cache, item directly needed + self.assertNotIn(id(location.access_rule), self.state.rule_builder_cache[1]) + self.assertTrue(location.can_reach(self.state)) + self.assertTrue(self.state.rule_builder_cache[1][id(location.access_rule)]) + self.state.collect(self.world.create_item("Item 3")) # does not clear cache as rule is already true + self.assertTrue(self.state.rule_builder_cache[1][id(location.access_rule)]) + + def test_region_cache_busting(self) -> None: + location = self.world.get_location("Location 2") + self.state.collect(self.world.create_item("Item 2")) # item directly needed for location rule + self.assertFalse(location.can_reach(self.state)) # populates cache + self.assertFalse(self.state.rule_builder_cache[1][id(location.access_rule)]) + + self.state.collect(self.world.create_item("Item 1")) # clears cache, item only needed for region 2 access + # cache gets cleared during the can_reach + self.assertTrue(location.can_reach(self.state)) + self.assertTrue(self.state.rule_builder_cache[1][id(location.access_rule)]) + + def test_location_cache_busting(self) -> None: + location = self.world.get_location("Location 5") + self.state.collect(self.world.create_item("Item 1")) # access to region 2 + self.state.collect(self.world.create_item("Item 3")) # access to region 3 + self.assertFalse(location.can_reach(self.state)) # populates cache + self.assertFalse(self.state.rule_builder_cache[1][id(location.access_rule)]) + + self.state.collect(self.world.create_item("Item 2")) # clears cache, item only needed for location 2 access + self.assertNotIn(id(location.access_rule), self.state.rule_builder_cache[1]) + self.assertTrue(location.can_reach(self.state)) + + def test_entrance_cache_busting(self) -> None: + location = self.world.get_location("Location 6") + self.state.collect(self.world.create_item("Item 2")) # item directly needed for location rule + self.assertFalse(location.can_reach(self.state)) # populates cache + self.assertFalse(self.state.rule_builder_cache[1][id(location.access_rule)]) + + self.state.collect(self.world.create_item("Item 1")) # clears cache, item only needed for entrance access + self.assertNotIn(id(location.access_rule), self.state.rule_builder_cache[1]) + self.assertTrue(location.can_reach(self.state)) + + def test_has_skips_cache(self) -> None: + entrance = self.world.get_entrance("Region 1 -> Region 2") + self.assertFalse(entrance.can_reach(self.state)) # does not populates cache + self.assertNotIn(id(entrance.access_rule), self.state.rule_builder_cache[1]) + + self.state.collect(self.world.create_item("Item 1")) # no need to clear cache, item directly needed + self.assertNotIn(id(entrance.access_rule), self.state.rule_builder_cache[1]) + self.assertTrue(entrance.can_reach(self.state)) + + +class TestCacheDisabled(RuleBuilderTestCase): + multiworld: MultiWorld # pyright: ignore[reportUninitializedInstanceVariable] + world: World # pyright: ignore[reportUninitializedInstanceVariable] + state: CachedCollectionState # pyright: ignore[reportUninitializedInstanceVariable] + player: int = 1 + + @override + def setUp(self) -> None: + super().setUp() + + self.multiworld = setup_solo_multiworld(self.world_cls, seed=0) + world = self.multiworld.worlds[1] + self.world = world + self.state = cast(CachedCollectionState, self.multiworld.state) + + region1 = Region("Region 1", self.player, self.multiworld) + region2 = Region("Region 2", self.player, self.multiworld) + region3 = Region("Region 3", self.player, self.multiworld) + self.multiworld.regions.extend([region1, region2, region3]) + + region1.add_locations({"Location 1": 1, "Location 2": 2, "Location 6": 6}, RuleBuilderLocation) + region2.add_locations({"Location 3": 3, "Location 4": 4}, RuleBuilderLocation) + region3.add_locations({"Location 5": 5}, RuleBuilderLocation) + + world.create_entrance(region1, region2, Has("Item 1")) + world.create_entrance(region1, region3, HasAny("Item 3", "Item 4")) + world.set_rule(world.get_location("Location 2"), CanReachRegion("Region 2") & Has("Item 2")) + world.set_rule(world.get_location("Location 4"), HasAll("Item 2", "Item 3")) + world.set_rule(world.get_location("Location 5"), CanReachLocation("Location 4")) + world.set_rule(world.get_location("Location 6"), CanReachEntrance("Region 1 -> Region 2") & Has("Item 2")) + + for i in range(1, LOC_COUNT + 1): + self.multiworld.itempool.append(world.create_item(f"Item {i}")) + + def test_item_logic(self) -> None: + entrance = self.world.get_entrance("Region 1 -> Region 2") + self.assertFalse(entrance.can_reach(self.state)) + self.assertFalse(self.state.rule_builder_cache[1]) + + self.state.collect(self.world.create_item("Item 1")) # item directly needed + self.assertFalse(self.state.rule_builder_cache[1]) + self.assertTrue(entrance.can_reach(self.state)) + + def test_region_logic(self) -> None: + location = self.world.get_location("Location 2") + self.state.collect(self.world.create_item("Item 2")) # item directly needed for location rule + self.assertFalse(location.can_reach(self.state)) + self.assertFalse(self.state.rule_builder_cache[1]) + + self.state.collect(self.world.create_item("Item 1")) # item only needed for region 2 access + self.assertTrue(location.can_reach(self.state)) + self.assertFalse(self.state.rule_builder_cache[1]) + + def test_location_logic(self) -> None: + location = self.world.get_location("Location 5") + self.state.collect(self.world.create_item("Item 1")) # access to region 2 + self.state.collect(self.world.create_item("Item 3")) # access to region 3 + self.assertFalse(location.can_reach(self.state)) + self.assertFalse(self.state.rule_builder_cache[1]) + + self.state.collect(self.world.create_item("Item 2")) # item only needed for location 2 access + self.assertFalse(self.state.rule_builder_cache[1]) + self.assertTrue(location.can_reach(self.state)) + + def test_entrance_logic(self) -> None: + location = self.world.get_location("Location 6") + self.state.collect(self.world.create_item("Item 2")) # item directly needed for location rule + self.assertFalse(location.can_reach(self.state)) + self.assertFalse(self.state.rule_builder_cache[1]) + + self.state.collect(self.world.create_item("Item 1")) # item only needed for entrance access + self.assertFalse(self.state.rule_builder_cache[1]) + self.assertTrue(location.can_reach(self.state)) + + +class TestRules(RuleBuilderTestCase): + multiworld: MultiWorld # pyright: ignore[reportUninitializedInstanceVariable] + world: World # pyright: ignore[reportUninitializedInstanceVariable] + state: CollectionState # pyright: ignore[reportUninitializedInstanceVariable] + player: int = 1 + + @override + def setUp(self) -> None: + super().setUp() + + self.multiworld = setup_solo_multiworld(self.world_cls, seed=0) + world = self.multiworld.worlds[1] + self.world = world + self.state = self.multiworld.state + + def test_true(self) -> None: + rule = True_() + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertTrue(resolved_rule(self.state)) + + def test_false(self) -> None: + rule = False_() + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + + def test_has(self) -> None: + rule = Has("Item 1") + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + item = self.world.create_item("Item 1") + self.state.collect(item) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item) + self.assertFalse(resolved_rule(self.state)) + + def test_has_all(self) -> None: + rule = HasAll("Item 1", "Item 2") + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + item1 = self.world.create_item("Item 1") + self.state.collect(item1) + self.assertFalse(resolved_rule(self.state)) + item2 = self.world.create_item("Item 2") + self.state.collect(item2) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item1) + self.assertFalse(resolved_rule(self.state)) + + def test_has_any(self) -> None: + item_names = ("Item 1", "Item 2") + rule = HasAny(*item_names) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + + for item_name in item_names: + item = self.world.create_item(item_name) + self.state.collect(item) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item) + self.assertFalse(resolved_rule(self.state)) + + def test_has_all_counts(self) -> None: + rule = HasAllCounts({"Item 1": 1, "Item 2": 2}) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + item1 = self.world.create_item("Item 1") + self.state.collect(item1) + self.assertFalse(resolved_rule(self.state)) + item2 = self.world.create_item("Item 2") + self.state.collect(item2) + self.assertFalse(resolved_rule(self.state)) + item2 = self.world.create_item("Item 2") + self.state.collect(item2) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item2) + self.assertFalse(resolved_rule(self.state)) + + def test_has_any_count(self) -> None: + item_counts = {"Item 1": 1, "Item 2": 2} + rule = HasAnyCount(item_counts) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + + for item_name, count in item_counts.items(): + item = self.world.create_item(item_name) + for _ in range(count): + self.assertFalse(resolved_rule(self.state)) + self.state.collect(item) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item) + self.assertFalse(resolved_rule(self.state)) + + def test_has_from_list(self) -> None: + item_names = ("Item 1", "Item 2", "Item 3") + rule = HasFromList(*item_names, count=2) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + + items: list[Item] = [] + for i, item_name in enumerate(item_names): + item = self.world.create_item(item_name) + self.state.collect(item) + items.append(item) + if i == 0: + self.assertFalse(resolved_rule(self.state)) + else: + self.assertTrue(resolved_rule(self.state)) + + for i in range(2): + self.state.remove(items[i]) + self.assertFalse(resolved_rule(self.state)) + + def test_has_from_list_unique(self) -> None: + item_names = ("Item 1", "Item 1", "Item 2") + rule = HasFromListUnique(*item_names, count=2) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + self.assertFalse(resolved_rule(self.state)) + + items: list[Item] = [] + for i, item_name in enumerate(item_names): + item = self.world.create_item(item_name) + self.state.collect(item) + items.append(item) + if i < 2: + self.assertFalse(resolved_rule(self.state)) + else: + self.assertTrue(resolved_rule(self.state)) + + self.state.remove(items[0]) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(items[1]) + self.assertFalse(resolved_rule(self.state)) + + def test_has_group(self) -> None: + rule = HasGroup("Group 1", count=2) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + + items: list[Item] = [] + for item_name in ("Item 1", "Item 2"): + self.assertFalse(resolved_rule(self.state)) + item = self.world.create_item(item_name) + self.state.collect(item) + items.append(item) + + self.assertTrue(resolved_rule(self.state)) + self.state.remove(items[0]) + self.assertFalse(resolved_rule(self.state)) + + def test_has_group_unique(self) -> None: + rule = HasGroupUnique("Group 1", count=2) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + + items: list[Item] = [] + for item_name in ("Item 1", "Item 1", "Item 2"): + self.assertFalse(resolved_rule(self.state)) + item = self.world.create_item(item_name) + self.state.collect(item) + items.append(item) + + self.assertTrue(resolved_rule(self.state)) + self.state.remove(items[0]) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(items[1]) + self.assertFalse(resolved_rule(self.state)) + + def test_completion_rule(self) -> None: + rule = Has("Item 1") + self.world.set_completion_rule(rule) + self.assertEqual(self.multiworld.can_beat_game(self.state), False) + self.state.collect(self.world.create_item("Item 1")) + self.assertEqual(self.multiworld.can_beat_game(self.state), True) + + +class TestSerialization(RuleBuilderTestCase): + maxDiff: int | None = None + + rule: ClassVar[Rule[Any]] = And( + Or( + Has("i1", count=4), + HasFromList("i2", "i3", "i4", count=2), + HasAnyCount({"i5": 2, "i6": 3}), + options=[OptionFilter(ToggleOption, 0)], + ), + Or( + HasAll("i7", "i8"), + HasAllCounts( + {"i9": 1, "i10": 5}, + options=[OptionFilter(ToggleOption, 1, operator="ne")], + filtered_resolution=True, + ), + CanReachRegion("r1"), + HasGroup("g1"), + ), + And( + HasAny("i11", "i12"), + CanReachLocation("l1", "r2"), + HasFromListUnique("i13", "i14"), + options=[ + OptionFilter(ToggleOption, ToggleOption.option_false), + OptionFilter(ChoiceOption, ChoiceOption.option_second, "ge"), + ], + ), + CanReachEntrance("e1"), + HasGroupUnique("g2", count=5), + ) + + rule_dict: ClassVar[dict[str, Any]] = { + "rule": "And", + "options": [], + "filtered_resolution": False, + "children": [ + { + "rule": "Or", + "options": [ + { + "option": "test.general.test_rule_builder.ToggleOption", + "value": 0, + "operator": "eq", + }, + ], + "filtered_resolution": False, + "children": [ + { + "rule": "Has", + "options": [], + "filtered_resolution": False, + "args": {"item_name": "i1", "count": 4}, + }, + { + "rule": "HasFromList", + "options": [], + "filtered_resolution": False, + "args": {"item_names": ("i2", "i3", "i4"), "count": 2}, + }, + { + "rule": "HasAnyCount", + "options": [], + "filtered_resolution": False, + "args": {"item_counts": {"i5": 2, "i6": 3}}, + }, + ], + }, + { + "rule": "Or", + "options": [], + "filtered_resolution": False, + "children": [ + { + "rule": "HasAll", + "options": [], + "filtered_resolution": False, + "args": {"item_names": ("i7", "i8")}, + }, + { + "rule": "HasAllCounts", + "options": [ + { + "option": "test.general.test_rule_builder.ToggleOption", + "value": 1, + "operator": "ne", + }, + ], + "filtered_resolution": True, + "args": {"item_counts": {"i9": 1, "i10": 5}}, + }, + { + "rule": "CanReachRegion", + "options": [], + "filtered_resolution": False, + "args": {"region_name": "r1"}, + }, + { + "rule": "HasGroup", + "options": [], + "filtered_resolution": False, + "args": {"item_name_group": "g1", "count": 1}, + }, + ], + }, + { + "rule": "And", + "options": [ + { + "option": "test.general.test_rule_builder.ToggleOption", + "value": 0, + "operator": "eq", + }, + { + "option": "test.general.test_rule_builder.ChoiceOption", + "value": 1, + "operator": "ge", + }, + ], + "filtered_resolution": False, + "children": [ + { + "rule": "HasAny", + "options": [], + "filtered_resolution": False, + "args": {"item_names": ("i11", "i12")}, + }, + { + "rule": "CanReachLocation", + "options": [], + "filtered_resolution": False, + "args": {"location_name": "l1", "parent_region_name": "r2", "skip_indirect_connection": False}, + }, + { + "rule": "HasFromListUnique", + "options": [], + "filtered_resolution": False, + "args": {"item_names": ("i13", "i14"), "count": 1}, + }, + ], + }, + { + "rule": "CanReachEntrance", + "options": [], + "filtered_resolution": False, + "args": {"entrance_name": "e1", "parent_region_name": ""}, + }, + { + "rule": "HasGroupUnique", + "options": [], + "filtered_resolution": False, + "args": {"item_name_group": "g2", "count": 5}, + }, + ], + } + + def test_serialize(self) -> None: + serialized_rule = self.rule.to_dict() + self.assertDictEqual(serialized_rule, self.rule_dict) + + def test_deserialize(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=(), seed=0) + world = multiworld.worlds[1] + deserialized_rule = world.rule_from_dict(self.rule_dict) + self.assertEqual(deserialized_rule, self.rule, str(deserialized_rule)) + + +class TestExplain(RuleBuilderTestCase): + multiworld: MultiWorld # pyright: ignore[reportUninitializedInstanceVariable] + world: World # pyright: ignore[reportUninitializedInstanceVariable] + state: CollectionState # pyright: ignore[reportUninitializedInstanceVariable] + player: int = 1 + + resolved_rule: ClassVar[Rule.Resolved] = And.Resolved( + ( + Or.Resolved( + ( + Has.Resolved("Item 1", count=4, player=1), + HasAll.Resolved(("Item 2", "Item 3"), player=1), + HasAny.Resolved(("Item 4", "Item 5"), player=1), + ), + player=1, + ), + HasAllCounts.Resolved((("Item 6", 1), ("Item 7", 5)), player=1), + HasAnyCount.Resolved((("Item 8", 2), ("Item 9", 3)), player=1), + HasFromList.Resolved(("Item 10", "Item 11", "Item 12"), count=2, player=1), + HasFromListUnique.Resolved(("Item 13", "Item 14"), player=1), + HasGroup.Resolved("Group 1", ("Item 15", "Item 16", "Item 17"), player=1), + HasGroupUnique.Resolved("Group 2", ("Item 18", "Item 19"), count=2, player=1), + CanReachRegion.Resolved("Region 2", player=1), + CanReachLocation.Resolved("Location 2", "Region 2", player=1), + CanReachEntrance.Resolved("Entrance 2", "Region 2", player=1), + True_.Resolved(player=1), + False_.Resolved(player=1), + ), + player=1, + ) + + @override + def setUp(self) -> None: + super().setUp() + + self.multiworld = setup_solo_multiworld(self.world_cls, seed=0) + world = self.multiworld.worlds[1] + self.world = world + self.state = self.multiworld.state + + region1 = Region("Region 1", self.player, self.multiworld) + region2 = Region("Region 2", self.player, self.multiworld) + region3 = Region("Region 3", self.player, self.multiworld) + self.multiworld.regions.extend([region1, region2, region3]) + + region2.add_locations({"Location 2": 1}, RuleBuilderLocation) + world.create_entrance(region1, region2, Has("Item 1")) + world.create_entrance(region2, region3, name="Entrance 2") + + def _collect_all(self) -> None: + for i in range(1, LOC_COUNT + 1): + for _ in range(10): + item = self.world.create_item(f"Item {i}") + self.state.collect(item) + + def test_explain_json_with_state_no_items(self) -> None: + expected: list[JSONMessagePart] = [ + {"type": "text", "text": "("}, + {"type": "text", "text": "("}, + {"type": "text", "text": "Missing "}, + {"type": "color", "color": "cyan", "text": "4"}, + {"type": "text", "text": "x "}, + {"type": "color", "color": "salmon", "text": "Item 1"}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Missing "}, + {"type": "color", "color": "cyan", "text": "some"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 2"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 3"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Missing "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 4"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 5"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Missing "}, + {"type": "color", "color": "cyan", "text": "some"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 6"}, + {"type": "text", "text": " x1"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 7"}, + {"type": "text", "text": " x5"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Missing "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 8"}, + {"type": "text", "text": " x2"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 9"}, + {"type": "text", "text": " x3"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "salmon", "text": "0/2"}, + {"type": "text", "text": " items from ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 10"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 11"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 12"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "salmon", "text": "0/1"}, + {"type": "text", "text": " unique items from ("}, + {"type": "text", "text": "Missing: "}, + {"type": "color", "color": "salmon", "text": "Item 13"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "salmon", "text": "Item 14"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "salmon", "text": "0/1"}, + {"type": "text", "text": " items from "}, + {"type": "color", "color": "cyan", "text": "Group 1"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "salmon", "text": "0/2"}, + {"type": "text", "text": " unique items from "}, + {"type": "color", "color": "cyan", "text": "Group 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Cannot reach region "}, + {"type": "color", "color": "yellow", "text": "Region 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Cannot reach location "}, + {"type": "location_name", "text": "Location 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Cannot reach entrance "}, + {"type": "entrance_name", "text": "Entrance 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "green", "text": "True"}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "salmon", "text": "False"}, + {"type": "text", "text": ")"}, + ] + assert self.resolved_rule.explain_json(self.state) == expected + + def test_explain_json_with_state_all_items(self) -> None: + self._collect_all() + + expected: list[JSONMessagePart] = [ + {"type": "text", "text": "("}, + {"type": "text", "text": "("}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "4"}, + {"type": "text", "text": "x "}, + {"type": "color", "color": "green", "text": "Item 1"}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 2"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 3"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "some"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 4"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 5"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 6"}, + {"type": "text", "text": " x1"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 7"}, + {"type": "text", "text": " x5"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "some"}, + {"type": "text", "text": " of ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 8"}, + {"type": "text", "text": " x2"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 9"}, + {"type": "text", "text": " x3"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "green", "text": "30/2"}, + {"type": "text", "text": " items from ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 10"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 11"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 12"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "green", "text": "2/1"}, + {"type": "text", "text": " unique items from ("}, + {"type": "text", "text": "Found: "}, + {"type": "color", "color": "green", "text": "Item 13"}, + {"type": "text", "text": ", "}, + {"type": "color", "color": "green", "text": "Item 14"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "green", "text": "30/1"}, + {"type": "text", "text": " items from "}, + {"type": "color", "color": "cyan", "text": "Group 1"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "green", "text": "2/2"}, + {"type": "text", "text": " unique items from "}, + {"type": "color", "color": "cyan", "text": "Group 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Reached region "}, + {"type": "color", "color": "yellow", "text": "Region 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Reached location "}, + {"type": "location_name", "text": "Location 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Reached entrance "}, + {"type": "entrance_name", "text": "Entrance 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "green", "text": "True"}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "salmon", "text": "False"}, + {"type": "text", "text": ")"}, + ] + assert self.resolved_rule.explain_json(self.state) == expected + + def test_explain_json_without_state(self) -> None: + expected: list[JSONMessagePart] = [ + {"type": "text", "text": "("}, + {"type": "text", "text": "("}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "4"}, + {"type": "text", "text": "x "}, + {"type": "item_name", "flags": 1, "text": "Item 1", "player": 1}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "item_name", "flags": 1, "text": "Item 2", "player": 1}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 3", "player": 1}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " | "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "any"}, + {"type": "text", "text": " of ("}, + {"type": "item_name", "flags": 1, "text": "Item 4", "player": 1}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 5", "player": 1}, + {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "all"}, + {"type": "text", "text": " of ("}, + {"type": "item_name", "flags": 1, "text": "Item 6", "player": 1}, + {"type": "text", "text": " x1"}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 7", "player": 1}, + {"type": "text", "text": " x5"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "any"}, + {"type": "text", "text": " of ("}, + {"type": "item_name", "flags": 1, "text": "Item 8", "player": 1}, + {"type": "text", "text": " x2"}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 9", "player": 1}, + {"type": "text", "text": " x3"}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "2"}, + {"type": "text", "text": "x items from ("}, + {"type": "item_name", "flags": 1, "text": "Item 10", "player": 1}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 11", "player": 1}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 12", "player": 1}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "1"}, + {"type": "text", "text": "x unique items from ("}, + {"type": "item_name", "flags": 1, "text": "Item 13", "player": 1}, + {"type": "text", "text": ", "}, + {"type": "item_name", "flags": 1, "text": "Item 14", "player": 1}, + {"type": "text", "text": ")"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "1"}, + {"type": "text", "text": " items from "}, + {"type": "color", "color": "cyan", "text": "Group 1"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Has "}, + {"type": "color", "color": "cyan", "text": "2"}, + {"type": "text", "text": " unique items from "}, + {"type": "color", "color": "cyan", "text": "Group 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Can reach region "}, + {"type": "color", "color": "yellow", "text": "Region 2"}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Can reach location "}, + {"type": "location_name", "text": "Location 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "text", "text": "Can reach entrance "}, + {"type": "entrance_name", "text": "Entrance 2", "player": 1}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "green", "text": "True"}, + {"type": "text", "text": " & "}, + {"type": "color", "color": "salmon", "text": "False"}, + {"type": "text", "text": ")"}, + ] + assert self.resolved_rule.explain_json() == expected + + def test_explain_str_with_state_no_items(self) -> None: + expected = ( + "((Missing 4x Item 1", + "| Missing some of (Missing: Item 2, Item 3)", + "| Missing all of (Missing: Item 4, Item 5))", + "& Missing some of (Missing: Item 6 x1, Item 7 x5)", + "& Missing all of (Missing: Item 8 x2, Item 9 x3)", + "& Has 0/2 items from (Missing: Item 10, Item 11, Item 12)", + "& Has 0/1 unique items from (Missing: Item 13, Item 14)", + "& Has 0/1 items from Group 1", + "& Has 0/2 unique items from Group 2", + "& Cannot reach region Region 2", + "& Cannot reach location Location 2", + "& Cannot reach entrance Entrance 2", + "& True", + "& False)", + ) + assert self.resolved_rule.explain_str(self.state) == " ".join(expected) + + def test_explain_str_with_state_all_items(self) -> None: + self._collect_all() + + expected = ( + "((Has 4x Item 1", + "| Has all of (Found: Item 2, Item 3)", + "| Has some of (Found: Item 4, Item 5))", + "& Has all of (Found: Item 6 x1, Item 7 x5)", + "& Has some of (Found: Item 8 x2, Item 9 x3)", + "& Has 30/2 items from (Found: Item 10, Item 11, Item 12)", + "& Has 2/1 unique items from (Found: Item 13, Item 14)", + "& Has 30/1 items from Group 1", + "& Has 2/2 unique items from Group 2", + "& Reached region Region 2", + "& Reached location Location 2", + "& Reached entrance Entrance 2", + "& True", + "& False)", + ) + assert self.resolved_rule.explain_str(self.state) == " ".join(expected) + + def test_explain_str_without_state(self) -> None: + expected = ( + "((Has 4x Item 1", + "| Has all of (Item 2, Item 3)", + "| Has any of (Item 4, Item 5))", + "& Has all of (Item 6 x1, Item 7 x5)", + "& Has any of (Item 8 x2, Item 9 x3)", + "& Has 2x items from (Item 10, Item 11, Item 12)", + "& Has a unique item from (Item 13, Item 14)", + "& Has an item from Group 1", + "& Has 2x unique items from Group 2", + "& Can reach region Region 2", + "& Can reach location Location 2", + "& Can reach entrance Entrance 2", + "& True", + "& False)", + ) + assert self.resolved_rule.explain_str() == " ".join(expected) + + def test_str(self) -> None: + expected = ( + "((Has 4x Item 1", + "| Has all of (Item 2, Item 3)", + "| Has any of (Item 4, Item 5))", + "& Has all of (Item 6 x1, Item 7 x5)", + "& Has any of (Item 8 x2, Item 9 x3)", + "& Has 2x items from (Item 10, Item 11, Item 12)", + "& Has a unique item from (Item 13, Item 14)", + "& Has an item from Group 1", + "& Has 2x unique items from Group 2", + "& Can reach region Region 2", + "& Can reach location Location 2", + "& Can reach entrance Entrance 2", + "& True", + "& False)", + ) + assert str(self.resolved_rule) == " ".join(expected) diff --git a/worlds/AutoWorld.py b/worlds/AutoWorld.py index 9d57f4f534..327e386c05 100644 --- a/worlds/AutoWorld.py +++ b/worlds/AutoWorld.py @@ -5,17 +5,18 @@ import logging import pathlib import sys import time +from collections.abc import Callable, Iterable, Mapping from random import Random -from dataclasses import make_dataclass -from typing import (Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Mapping, Optional, Set, TextIO, Tuple, +from typing import (Any, ClassVar, Dict, FrozenSet, List, Optional, Self, Set, TextIO, Tuple, TYPE_CHECKING, Type, Union) from Options import item_and_loc_options, ItemsAccessibility, OptionGroup, PerGameCommonOptions -from BaseClasses import CollectionState +from BaseClasses import CollectionState, Entrance +from rule_builder.rules import CustomRuleRegister, Rule from Utils import Version if TYPE_CHECKING: - from BaseClasses import MultiWorld, Item, Location, Tutorial, Region, Entrance + from BaseClasses import CollectionRule, Item, Location, MultiWorld, Region, Tutorial from NetUtils import GamesPackage, MultiData from settings import Group @@ -177,7 +178,8 @@ def _timed_call(method: Callable[..., Any], *args: Any, def call_single(multiworld: "MultiWorld", method_name: str, player: int, *args: Any) -> Any: - method = getattr(multiworld.worlds[player], method_name) + world = multiworld.worlds[player] + method = getattr(world, method_name) try: ret = _timed_call(method, *args, multiworld=multiworld, player=player) except Exception as e: @@ -188,6 +190,10 @@ def call_single(multiworld: "MultiWorld", method_name: str, player: int, *args: logging.error(message) raise e else: + # Convenience for CachedRuleBuilderWorld users: Ensure that caching setup function is called + # Can be removed once dependency system is improved + if method_name == "set_rules" and hasattr(world, "register_rule_builder_dependencies"): + call_single(multiworld, "register_rule_builder_dependencies", player) return ret @@ -549,6 +555,10 @@ class World(metaclass=AutoWorldRegister): return True return False + def reached_region(self, state: "CollectionState", region: "Region") -> None: + """Called when a region is newly reachable by the state.""" + pass + # following methods should not need to be overridden. def create_filler(self) -> "Item": return self.create_item(self.get_filler_item_name()) @@ -597,6 +607,64 @@ class World(metaclass=AutoWorldRegister): res["checksum"] = data_package_checksum(res) return res + @classmethod + def get_rule_cls(cls, name: str) -> type[Rule[Self]]: + """Returns the world-registered or default rule with the given name""" + return CustomRuleRegister.get_rule_cls(cls.game, name) + + @classmethod + def rule_from_dict(cls, data: Mapping[str, Any]) -> Rule[Self]: + """Create a rule instance from a serialized dict representation""" + name = data.get("rule", "") + rule_class = cls.get_rule_cls(name) + return rule_class.from_dict(data, cls) + + def set_rule(self, spot: Location | Entrance, rule: CollectionRule | Rule[Any]) -> None: + """Sets an access rule for a location or entrance""" + if isinstance(rule, Rule): + rule = rule.resolve(self) + self.register_rule_dependencies(rule) + if isinstance(spot, Entrance): + self._register_rule_indirects(rule, spot) + spot.access_rule = rule + + def set_completion_rule(self, rule: CollectionRule | Rule[Any]) -> None: + """Set the completion rule for this world""" + if isinstance(rule, Rule): + rule = rule.resolve(self) + self.register_rule_dependencies(rule) + self.multiworld.completion_condition[self.player] = rule + + def create_entrance( + self, + from_region: Region, + to_region: Region, + rule: CollectionRule | Rule[Any] | None = None, + name: str | None = None, + force_creation: bool = False, + ) -> Entrance | None: + """Try to create an entrance between regions with the given rule, + skipping it if the rule resolves to False (unless force_creation is True)""" + if rule is not None and isinstance(rule, Rule): + rule = rule.resolve(self) + if rule.always_false and not force_creation: + return None + self.register_rule_dependencies(rule) + + entrance = from_region.connect(to_region, name, rule=rule) + if rule and isinstance(rule, Rule.Resolved): + self._register_rule_indirects(rule, entrance) + return entrance + + def register_rule_dependencies(self, resolved_rule: Rule.Resolved) -> None: + """Hook for registering dependencies when a rule is assigned for this world""" + pass + + def _register_rule_indirects(self, resolved_rule: Rule.Resolved, entrance: Entrance) -> None: + if self.explicit_indirect_conditions: + for indirect_region in resolved_rule.region_dependencies().keys(): + self.multiworld.register_indirect_condition(self.get_region(indirect_region), entrance) + # any methods attached to this can be used as part of CollectionState, # please use a prefix as all of them get clobbered together diff --git a/worlds/generic/Rules.py b/worlds/generic/Rules.py index 31d725bff7..bfb79bbc29 100644 --- a/worlds/generic/Rules.py +++ b/worlds/generic/Rules.py @@ -2,16 +2,10 @@ import collections import logging import typing -from BaseClasses import LocationProgressType, MultiWorld, Location, Region, Entrance +from BaseClasses import (CollectionRule, CollectionState, Entrance, Item, Location, + LocationProgressType, MultiWorld, Region) -if typing.TYPE_CHECKING: - import BaseClasses - - CollectionRule = typing.Callable[[BaseClasses.CollectionState], bool] - ItemRule = typing.Callable[[BaseClasses.Item], bool] -else: - CollectionRule = typing.Callable[[object], bool] - ItemRule = typing.Callable[[object], bool] +ItemRule = typing.Callable[[Item], bool] def locality_needed(multiworld: MultiWorld) -> bool: @@ -96,11 +90,11 @@ def exclusion_rules(multiworld: MultiWorld, player: int, exclude_locations: typi logging.warning(f"Unable to exclude location {loc_name} in player {player}'s world.") -def set_rule(spot: typing.Union["BaseClasses.Location", "BaseClasses.Entrance"], rule: CollectionRule): +def set_rule(spot: typing.Union[Location, Entrance], rule: CollectionRule): spot.access_rule = rule -def add_rule(spot: typing.Union["BaseClasses.Location", "BaseClasses.Entrance"], rule: CollectionRule, combine="and"): +def add_rule(spot: typing.Union[Location, Entrance], rule: CollectionRule, combine="and"): old_rule = spot.access_rule # empty rule, replace instead of add if old_rule is Location.access_rule or old_rule is Entrance.access_rule: @@ -112,7 +106,7 @@ def add_rule(spot: typing.Union["BaseClasses.Location", "BaseClasses.Entrance"], spot.access_rule = lambda state: rule(state) or old_rule(state) -def forbid_item(location: "BaseClasses.Location", item: str, player: int): +def forbid_item(location: Location, item: str, player: int): old_rule = location.item_rule # empty rule if old_rule is Location.item_rule: @@ -121,18 +115,18 @@ def forbid_item(location: "BaseClasses.Location", item: str, player: int): location.item_rule = lambda i: (i.name != item or i.player != player) and old_rule(i) -def forbid_items_for_player(location: "BaseClasses.Location", items: typing.Set[str], player: int): +def forbid_items_for_player(location: Location, items: typing.Set[str], player: int): old_rule = location.item_rule location.item_rule = lambda i: (i.player != player or i.name not in items) and old_rule(i) -def forbid_items(location: "BaseClasses.Location", items: typing.Set[str]): +def forbid_items(location: Location, items: typing.Set[str]): """unused, but kept as a debugging tool.""" old_rule = location.item_rule location.item_rule = lambda i: i.name not in items and old_rule(i) -def add_item_rule(location: "BaseClasses.Location", rule: ItemRule, combine: str = "and"): +def add_item_rule(location: Location, rule: ItemRule, combine: str = "and"): old_rule = location.item_rule # empty rule, replace instead of add if old_rule is Location.item_rule: @@ -144,7 +138,7 @@ def add_item_rule(location: "BaseClasses.Location", rule: ItemRule, combine: str location.item_rule = lambda item: rule(item) or old_rule(item) -def item_name_in_location_names(state: "BaseClasses.CollectionState", item: str, player: int, +def item_name_in_location_names(state: CollectionState, item: str, player: int, location_name_player_pairs: typing.Sequence[typing.Tuple[str, int]]) -> bool: for location in location_name_player_pairs: if location_item_name(state, location[0], location[1]) == (item, player): @@ -153,14 +147,14 @@ def item_name_in_location_names(state: "BaseClasses.CollectionState", item: str, def item_name_in_locations(item: str, player: int, - locations: typing.Sequence["BaseClasses.Location"]) -> bool: + locations: typing.Sequence[Location]) -> bool: for location in locations: if location.item and location.item.name == item and location.item.player == player: return True return False -def location_item_name(state: "BaseClasses.CollectionState", location: str, player: int) -> \ +def location_item_name(state: CollectionState, location: str, player: int) -> \ typing.Optional[typing.Tuple[str, int]]: location = state.multiworld.get_location(location, player) if location.item is None: