Rule Builder: Add field resolvers (#5919)

This commit is contained in:
Ian Robinson
2026-03-30 12:19:10 -04:00
committed by GitHub
parent 58a6407040
commit c640d2fa24
5 changed files with 341 additions and 38 deletions

View File

@@ -3,6 +3,7 @@
"../BizHawkClient.py",
"../Patch.py",
"../rule_builder/cached_world.py",
"../rule_builder/field_resolvers.py",
"../rule_builder/options.py",
"../rule_builder/rules.py",
"../test/param.py",

View File

@@ -129,6 +129,42 @@ common_rule_only_on_easy = common_rule & easy_filter
common_rule_skipped_on_easy = common_rule | easy_filter
```
### Field resolvers
When creating rules you may sometimes need to set a field to a value that depends on the world instance. You can use a `FieldResolver` to define how to populate that field when the rule is being resolved.
There are two build-in field resolvers:
- `FromOption`: Resolves to the value of the given option
- `FromWorldAttr`: Resolves to the value of the given world instance attribute, can specify a dotted path `a.b.c` to get a nested attribute or dict item
```python
world.options.mcguffin_count = 5
world.precalculated_value = 99
rule = (
Has("A", count=FromOption(McguffinCount))
| HasGroup("Important items", count=FromWorldAttr("precalculated_value"))
)
# Results in Has("A", count=5) | HasGroup("Important items", count=99)
```
You can define your own resolvers by creating a class that inherits from `FieldResolver`, provides your game name, and implements a `resolve` function:
```python
@dataclasses.dataclass(frozen=True)
class FromCustomResolution(FieldResolver, game="MyGame"):
modifier: str
@override
def resolve(self, world: "World") -> Any:
return some_math_calculation(world, self.modifier)
rule = Has("Combat Level", count=FromCustomResolution("combat"))
```
If you want to support rule serialization and your resolver contains non-serializable properties you may need to override `to_dict` or `from_dict`.
## Enabling caching
The rule builder provides a `CachedRuleBuilderWorld` base class for your `World` class that enables caching on your rules.

View File

@@ -0,0 +1,162 @@
import dataclasses
import importlib
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast, overload
from typing_extensions import override
from Options import Option
if TYPE_CHECKING:
from worlds.AutoWorld import World
class FieldResolverRegister:
"""A container class to contain world custom resolvers"""
custom_resolvers: ClassVar[dict[str, dict[str, type["FieldResolver"]]]] = {}
"""
A mapping of game name to mapping of resolver name to resolver class
to hold custom resolvers implemented by worlds
"""
@classmethod
def get_resolver_cls(cls, game_name: str, resolver_name: str) -> type["FieldResolver"]:
"""Returns the world-registered or default resolver with the given name"""
custom_resolver_classes = cls.custom_resolvers.get(game_name, {})
if resolver_name not in DEFAULT_RESOLVERS and resolver_name not in custom_resolver_classes:
raise ValueError(f"Resolver '{resolver_name}' for game '{game_name}' not found")
return custom_resolver_classes.get(resolver_name) or DEFAULT_RESOLVERS[resolver_name]
@dataclasses.dataclass(frozen=True)
class FieldResolver(ABC):
@abstractmethod
def resolve(self, world: "World") -> Any: ...
def to_dict(self) -> dict[str, Any]:
"""Returns a JSON compatible dict representation of this resolver"""
fields = {field.name: getattr(self, field.name, None) for field in dataclasses.fields(self)}
return {
"resolver": self.__class__.__name__,
**fields,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
"""Returns a new instance of this resolver from a serialized dict representation"""
assert data.get("resolver", None) == cls.__name__
return cls(**{k: v for k, v in data.items() if k != "resolver"})
@override
def __str__(self) -> str:
return self.__class__.__name__
@classmethod
def __init_subclass__(cls, /, game: str) -> None:
if game != "Archipelago":
custom_resolvers = FieldResolverRegister.custom_resolvers.setdefault(game, {})
if cls.__qualname__ in custom_resolvers:
raise TypeError(f"Resolver {cls.__qualname__} has already been registered for game {game}")
custom_resolvers[cls.__qualname__] = cls
elif cls.__module__ != "rule_builder.field_resolvers":
raise TypeError("You cannot define custom resolvers for the base Archipelago world")
@dataclasses.dataclass(frozen=True)
class FromOption(FieldResolver, game="Archipelago"):
option: type[Option[Any]]
field: str = "value"
@override
def resolve(self, world: "World") -> Any:
option_name = next(
(name for name, cls in world.options.__class__.type_hints.items() if cls is self.option),
None,
)
if option_name is None:
raise ValueError(
f"Cannot find option {self.option.__name__} in options class {world.options.__class__.__name__}"
)
opt = cast(Option[Any] | None, getattr(world.options, option_name, None))
if opt is None:
raise ValueError(f"Invalid option: {option_name}")
return getattr(opt, self.field)
@override
def to_dict(self) -> dict[str, Any]:
return {
"resolver": "FromOption",
"option": f"{self.option.__module__}.{self.option.__name__}",
"field": self.field,
}
@override
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
if "option" not in data:
raise ValueError("Missing required option")
option_path = data["option"]
try:
option_mod_name, option_cls_name = option_path.rsplit(".", 1)
option_module = importlib.import_module(option_mod_name)
option = getattr(option_module, option_cls_name, None)
except (ValueError, ImportError) as e:
raise ValueError(f"Cannot parse option '{option_path}'") from e
if option is None or not issubclass(option, Option):
raise ValueError(f"Invalid option '{option_path}' returns type '{option}' instead of Option subclass")
return cls(cast(type[Option[Any]], option), data.get("field", "value"))
@override
def __str__(self) -> str:
field = f".{self.field}" if self.field != "value" else ""
return f"FromOption({self.option.__name__}{field})"
@dataclasses.dataclass(frozen=True)
class FromWorldAttr(FieldResolver, game="Archipelago"):
name: str
@override
def resolve(self, world: "World") -> Any:
obj: Any = world
for field in self.name.split("."):
if obj is None:
return None
if isinstance(obj, Mapping):
obj = obj.get(field, None) # pyright: ignore[reportUnknownMemberType]
else:
obj = getattr(obj, field, None)
return obj
@override
def __str__(self) -> str:
return f"FromWorldAttr({self.name})"
T = TypeVar("T")
@overload
def resolve_field(field: Any, world: "World", expected_type: type[T]) -> T: ...
@overload
def resolve_field(field: Any, world: "World", expected_type: None = None) -> Any: ...
def resolve_field(field: Any, world: "World", expected_type: type[T] | None = None) -> T | Any:
if isinstance(field, FieldResolver):
field = field.resolve(world)
if expected_type:
assert isinstance(field, expected_type), f"Expected type {expected_type} but got {type(field)}"
return field
DEFAULT_RESOLVERS = {
resolver_name: resolver_class
for resolver_name, resolver_class in locals().items()
if isinstance(resolver_class, type)
and issubclass(resolver_class, FieldResolver)
and resolver_class is not FieldResolver
}

View File

@@ -7,6 +7,7 @@ from typing_extensions import TypeVar, dataclass_transform, override
from BaseClasses import CollectionState
from NetUtils import JSONMessagePart
from .field_resolvers import FieldResolver, FieldResolverRegister, resolve_field
from .options import OptionFilter
if TYPE_CHECKING:
@@ -108,11 +109,14 @@ class Rule(Generic[TWorld]):
def to_dict(self) -> dict[str, Any]:
"""Returns a JSON compatible dict representation of this rule"""
args = {
field.name: getattr(self, field.name, None)
for field in dataclasses.fields(self)
if field.name not in ("options", "filtered_resolution")
}
args = {}
for field in dataclasses.fields(self):
if field.name in ("options", "filtered_resolution"):
continue
value = getattr(self, field.name, None)
if isinstance(value, FieldResolver):
value = value.to_dict()
args[field.name] = value
return {
"rule": self.__class__.__qualname__,
"options": [o.to_dict() for o in self.options],
@@ -124,7 +128,19 @@ class Rule(Generic[TWorld]):
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
"""Returns a new instance of this rule from a serialized dict representation"""
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(**data.get("args", {}), options=options, filtered_resolution=data.get("filtered_resolution", False))
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
return cls(**args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@classmethod
def _parse_field_resolvers(cls, data: Mapping[str, Any], game_name: str) -> dict[str, Any]:
result: dict[str, Any] = {}
for name, value in data.items():
if isinstance(value, dict) and "resolver" in value:
resolver_cls = FieldResolverRegister.get_resolver_cls(game_name, value["resolver"]) # pyright: ignore[reportUnknownArgumentType]
result[name] = resolver_cls.from_dict(value) # pyright: ignore[reportUnknownArgumentType]
else:
result[name] = value
return result
def __and__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]":
"""Combines two rules or a rule and an option filter into an And rule"""
@@ -688,24 +704,24 @@ class Filtered(WrapperRule[TWorld], game="Archipelago"):
class Has(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has at least `count` of a given item"""
item_name: str
item_name: str | FieldResolver
"""The item to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The count the player is required to have"""
@override
def _instantiate(self, world: TWorld) -> Rule.Resolved:
return self.Resolved(
self.item_name,
self.count,
resolve_field(self.item_name, world, str),
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name}{count}{options})"
@@ -991,7 +1007,7 @@ class HasAny(Rule[TWorld], game="Archipelago"):
class HasAllCounts(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has all of the specified counts of the given items"""
item_counts: dict[str, int]
item_counts: Mapping[str, int | FieldResolver]
"""A mapping of item name to count to check for"""
@override
@@ -1002,12 +1018,30 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
if len(self.item_counts) == 1:
item = next(iter(self.item_counts))
return Has(item, self.item_counts[item]).resolve(world)
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
return self.Resolved(
tuple(self.item_counts.items()),
item_counts,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["args"]["item_counts"] = {
key: value.to_dict() if isinstance(value, FieldResolver) else value
for key, value in output["args"]["item_counts"].items()
}
return output
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = data.get("args", {})
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
@override
def __str__(self) -> str:
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
@@ -1096,7 +1130,7 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
class HasAnyCount(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has any of the specified counts of the given items"""
item_counts: dict[str, int]
item_counts: Mapping[str, int | FieldResolver]
"""A mapping of item name to count to check for"""
@override
@@ -1107,12 +1141,30 @@ class HasAnyCount(Rule[TWorld], game="Archipelago"):
if len(self.item_counts) == 1:
item = next(iter(self.item_counts))
return Has(item, self.item_counts[item]).resolve(world)
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
return self.Resolved(
tuple(self.item_counts.items()),
item_counts,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["args"]["item_counts"] = {
key: value.to_dict() if isinstance(value, FieldResolver) else value
for key, value in output["args"]["item_counts"].items()
}
return output
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = data.get("args", {})
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
@override
def __str__(self) -> str:
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
@@ -1204,13 +1256,13 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
item_names: tuple[str, ...]
"""A tuple of item names to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
def __init__(
self,
*item_names: str,
count: int = 1,
count: int | FieldResolver = 1,
options: Iterable[OptionFilter] = (),
filtered_resolution: bool = False,
) -> None:
@@ -1227,7 +1279,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
return Has(self.item_names[0], self.count).resolve(world)
return self.Resolved(
self.item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@@ -1235,7 +1287,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = {**data.get("args", {})}
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
item_names = args.pop("item_names", ())
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@@ -1338,13 +1390,13 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
item_names: tuple[str, ...]
"""A tuple of item names to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
def __init__(
self,
*item_names: str,
count: int = 1,
count: int | FieldResolver = 1,
options: Iterable[OptionFilter] = (),
filtered_resolution: bool = False,
) -> None:
@@ -1354,14 +1406,15 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
@override
def _instantiate(self, world: TWorld) -> Rule.Resolved:
if len(self.item_names) == 0 or len(self.item_names) < self.count:
count = resolve_field(self.count, world, int)
if len(self.item_names) == 0 or len(self.item_names) < count:
# match state.has_from_list_unique
return False_().resolve(world)
if len(self.item_names) == 1:
return Has(self.item_names[0]).resolve(world)
return self.Resolved(
self.item_names,
self.count,
count,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@@ -1369,7 +1422,7 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = {**data.get("args", {})}
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
item_names = args.pop("item_names", ())
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@@ -1468,7 +1521,7 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
item_name_group: str
"""The name of the item group containing the items"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
@override
@@ -1477,14 +1530,14 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
return self.Resolved(
self.item_name_group,
item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"
@@ -1542,7 +1595,7 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
item_name_group: str
"""The name of the item group containing the items"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
@override
@@ -1551,14 +1604,14 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
return self.Resolved(
self.item_name_group,
item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"

View File

@@ -6,8 +6,9 @@ from typing_extensions import override
from BaseClasses import CollectionState, Item, ItemClassification, Location, MultiWorld, Region
from NetUtils import JSONMessagePart
from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Toggle
from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Range, Toggle
from rule_builder.cached_world import CachedRuleBuilderWorld
from rule_builder.field_resolvers import FieldResolver, FromOption, FromWorldAttr, resolve_field
from rule_builder.options import Operator, OptionFilter
from rule_builder.rules import (
And,
@@ -59,12 +60,20 @@ class SetOption(OptionSet):
valid_keys: ClassVar[set[str]] = {"one", "two", "three"} # pyright: ignore[reportIncompatibleVariableOverride]
class RangeOption(Range):
auto_display_name = True
range_start = 1
range_end = 10
default = 5
@dataclass
class RuleBuilderOptions(PerGameCommonOptions):
toggle_option: ToggleOption
choice_option: ChoiceOption
text_option: FreeTextOption
set_option: SetOption
range_option: RangeOption
GAME_NAME = "Rule Builder Test Game"
@@ -659,14 +668,15 @@ class TestRules(RuleBuilderTestCase):
self.assertFalse(resolved_rule(self.state))
def test_has_any_count(self) -> None:
item_counts = {"Item 1": 1, "Item 2": 2}
item_counts: dict[str, int | FieldResolver] = {"Item 1": 1, "Item 2": 2}
rule = HasAnyCount(item_counts)
resolved_rule = rule.resolve(self.world)
self.world.register_rule_dependencies(resolved_rule)
for item_name, count in item_counts.items():
item = self.world.create_item(item_name)
for _ in range(count):
num_items = resolve_field(count, self.world, int)
for _ in range(num_items):
self.assertFalse(resolved_rule(self.state))
self.state.collect(item)
self.assertTrue(resolved_rule(self.state))
@@ -763,7 +773,7 @@ class TestSerialization(RuleBuilderTestCase):
rule: ClassVar[Rule[Any]] = And(
Or(
Has("i1", count=4),
Has("i1", count=FromOption(RangeOption)),
HasFromList("i2", "i3", "i4", count=2),
HasAnyCount({"i5": 2, "i6": 3}),
options=[OptionFilter(ToggleOption, 0)],
@@ -771,7 +781,7 @@ class TestSerialization(RuleBuilderTestCase):
Or(
HasAll("i7", "i8"),
HasAllCounts(
{"i9": 1, "i10": 5},
{"i9": 1, "i10": FromWorldAttr("instance_data.i10_count")},
options=[OptionFilter(ToggleOption, 1, operator="ne")],
filtered_resolution=True,
),
@@ -811,7 +821,14 @@ class TestSerialization(RuleBuilderTestCase):
"rule": "Has",
"options": [],
"filtered_resolution": False,
"args": {"item_name": "i1", "count": 4},
"args": {
"item_name": "i1",
"count": {
"resolver": "FromOption",
"option": "test.general.test_rule_builder.RangeOption",
"field": "value",
},
},
},
{
"rule": "HasFromList",
@@ -848,7 +865,12 @@ class TestSerialization(RuleBuilderTestCase):
},
],
"filtered_resolution": True,
"args": {"item_counts": {"i9": 1, "i10": 5}},
"args": {
"item_counts": {
"i9": 1,
"i10": {"resolver": "FromWorldAttr", "name": "instance_data.i10_count"},
}
},
},
{
"rule": "CanReachRegion",
@@ -923,7 +945,7 @@ class TestSerialization(RuleBuilderTestCase):
multiworld = setup_solo_multiworld(self.world_cls, steps=(), seed=0)
world = multiworld.worlds[1]
deserialized_rule = world.rule_from_dict(self.rule_dict)
self.assertEqual(deserialized_rule, self.rule, str(deserialized_rule))
self.assertEqual(deserialized_rule, self.rule, f"\n{deserialized_rule}\n{self.rule}")
class TestExplain(RuleBuilderTestCase):
@@ -1342,3 +1364,32 @@ class TestExplain(RuleBuilderTestCase):
"& False)",
)
assert str(self.resolved_rule) == " ".join(expected)
@classvar_matrix(
rules=(
(
Has("A", FromOption(RangeOption)),
Has.Resolved("A", count=5, player=1),
),
(
Has("B", FromWorldAttr("pre_calculated")),
Has.Resolved("B", count=3, player=1),
),
(
Has("C", FromWorldAttr("instance_data.key")),
Has.Resolved("C", count=7, player=1),
),
)
)
class TestFieldResolvers(RuleBuilderTestCase):
rules: ClassVar[tuple[Rule[Any], Rule.Resolved]]
def test_simplify(self) -> None:
multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0)
world = multiworld.worlds[1]
world.pre_calculated = 3 # pyright: ignore[reportAttributeAccessIssue]
world.instance_data = {"key": 7} # pyright: ignore[reportAttributeAccessIssue]
rule, expected = self.rules
resolved_rule = rule.resolve(world)
self.assertEqual(resolved_rule, expected, f"\n{resolved_rule}\n{expected}")