diff --git a/.github/pyright-config.json b/.github/pyright-config.json index fba044da06..c5432dbf3c 100644 --- a/.github/pyright-config.json +++ b/.github/pyright-config.json @@ -3,6 +3,7 @@ "../BizHawkClient.py", "../Patch.py", "../rule_builder/cached_world.py", + "../rule_builder/field_resolvers.py", "../rule_builder/options.py", "../rule_builder/rules.py", "../test/param.py", diff --git a/docs/rule builder.md b/docs/rule builder.md index 4f9102a2ba..c3a8fcb6c4 100644 --- a/docs/rule builder.md +++ b/docs/rule builder.md @@ -129,6 +129,42 @@ common_rule_only_on_easy = common_rule & easy_filter common_rule_skipped_on_easy = common_rule | easy_filter ``` +### Field resolvers + +When creating rules you may sometimes need to set a field to a value that depends on the world instance. You can use a `FieldResolver` to define how to populate that field when the rule is being resolved. + +There are two build-in field resolvers: + +- `FromOption`: Resolves to the value of the given option +- `FromWorldAttr`: Resolves to the value of the given world instance attribute, can specify a dotted path `a.b.c` to get a nested attribute or dict item + +```python +world.options.mcguffin_count = 5 +world.precalculated_value = 99 +rule = ( + Has("A", count=FromOption(McguffinCount)) + | HasGroup("Important items", count=FromWorldAttr("precalculated_value")) +) +# Results in Has("A", count=5) | HasGroup("Important items", count=99) +``` + +You can define your own resolvers by creating a class that inherits from `FieldResolver`, provides your game name, and implements a `resolve` function: + +```python +@dataclasses.dataclass(frozen=True) +class FromCustomResolution(FieldResolver, game="MyGame"): + modifier: str + + @override + def resolve(self, world: "World") -> Any: + return some_math_calculation(world, self.modifier) + + +rule = Has("Combat Level", count=FromCustomResolution("combat")) +``` + +If you want to support rule serialization and your resolver contains non-serializable properties you may need to override `to_dict` or `from_dict`. + ## Enabling caching The rule builder provides a `CachedRuleBuilderWorld` base class for your `World` class that enables caching on your rules. diff --git a/rule_builder/field_resolvers.py b/rule_builder/field_resolvers.py new file mode 100644 index 0000000000..1e5def6b44 --- /dev/null +++ b/rule_builder/field_resolvers.py @@ -0,0 +1,162 @@ +import dataclasses +import importlib +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast, overload + +from typing_extensions import override + +from Options import Option + +if TYPE_CHECKING: + from worlds.AutoWorld import World + + +class FieldResolverRegister: + """A container class to contain world custom resolvers""" + + custom_resolvers: ClassVar[dict[str, dict[str, type["FieldResolver"]]]] = {} + """ + A mapping of game name to mapping of resolver name to resolver class + to hold custom resolvers implemented by worlds + """ + + @classmethod + def get_resolver_cls(cls, game_name: str, resolver_name: str) -> type["FieldResolver"]: + """Returns the world-registered or default resolver with the given name""" + custom_resolver_classes = cls.custom_resolvers.get(game_name, {}) + if resolver_name not in DEFAULT_RESOLVERS and resolver_name not in custom_resolver_classes: + raise ValueError(f"Resolver '{resolver_name}' for game '{game_name}' not found") + return custom_resolver_classes.get(resolver_name) or DEFAULT_RESOLVERS[resolver_name] + + +@dataclasses.dataclass(frozen=True) +class FieldResolver(ABC): + @abstractmethod + def resolve(self, world: "World") -> Any: ... + + def to_dict(self) -> dict[str, Any]: + """Returns a JSON compatible dict representation of this resolver""" + fields = {field.name: getattr(self, field.name, None) for field in dataclasses.fields(self)} + return { + "resolver": self.__class__.__name__, + **fields, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: + """Returns a new instance of this resolver from a serialized dict representation""" + assert data.get("resolver", None) == cls.__name__ + return cls(**{k: v for k, v in data.items() if k != "resolver"}) + + @override + def __str__(self) -> str: + return self.__class__.__name__ + + @classmethod + def __init_subclass__(cls, /, game: str) -> None: + if game != "Archipelago": + custom_resolvers = FieldResolverRegister.custom_resolvers.setdefault(game, {}) + if cls.__qualname__ in custom_resolvers: + raise TypeError(f"Resolver {cls.__qualname__} has already been registered for game {game}") + custom_resolvers[cls.__qualname__] = cls + elif cls.__module__ != "rule_builder.field_resolvers": + raise TypeError("You cannot define custom resolvers for the base Archipelago world") + + +@dataclasses.dataclass(frozen=True) +class FromOption(FieldResolver, game="Archipelago"): + option: type[Option[Any]] + field: str = "value" + + @override + def resolve(self, world: "World") -> Any: + option_name = next( + (name for name, cls in world.options.__class__.type_hints.items() if cls is self.option), + None, + ) + + if option_name is None: + raise ValueError( + f"Cannot find option {self.option.__name__} in options class {world.options.__class__.__name__}" + ) + opt = cast(Option[Any] | None, getattr(world.options, option_name, None)) + if opt is None: + raise ValueError(f"Invalid option: {option_name}") + return getattr(opt, self.field) + + @override + def to_dict(self) -> dict[str, Any]: + return { + "resolver": "FromOption", + "option": f"{self.option.__module__}.{self.option.__name__}", + "field": self.field, + } + + @override + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: + if "option" not in data: + raise ValueError("Missing required option") + + option_path = data["option"] + try: + option_mod_name, option_cls_name = option_path.rsplit(".", 1) + option_module = importlib.import_module(option_mod_name) + option = getattr(option_module, option_cls_name, None) + except (ValueError, ImportError) as e: + raise ValueError(f"Cannot parse option '{option_path}'") from e + if option is None or not issubclass(option, Option): + raise ValueError(f"Invalid option '{option_path}' returns type '{option}' instead of Option subclass") + + return cls(cast(type[Option[Any]], option), data.get("field", "value")) + + @override + def __str__(self) -> str: + field = f".{self.field}" if self.field != "value" else "" + return f"FromOption({self.option.__name__}{field})" + + +@dataclasses.dataclass(frozen=True) +class FromWorldAttr(FieldResolver, game="Archipelago"): + name: str + + @override + def resolve(self, world: "World") -> Any: + obj: Any = world + for field in self.name.split("."): + if obj is None: + return None + if isinstance(obj, Mapping): + obj = obj.get(field, None) # pyright: ignore[reportUnknownMemberType] + else: + obj = getattr(obj, field, None) + return obj + + @override + def __str__(self) -> str: + return f"FromWorldAttr({self.name})" + + +T = TypeVar("T") + + +@overload +def resolve_field(field: Any, world: "World", expected_type: type[T]) -> T: ... +@overload +def resolve_field(field: Any, world: "World", expected_type: None = None) -> Any: ... +def resolve_field(field: Any, world: "World", expected_type: type[T] | None = None) -> T | Any: + if isinstance(field, FieldResolver): + field = field.resolve(world) + if expected_type: + assert isinstance(field, expected_type), f"Expected type {expected_type} but got {type(field)}" + return field + + +DEFAULT_RESOLVERS = { + resolver_name: resolver_class + for resolver_name, resolver_class in locals().items() + if isinstance(resolver_class, type) + and issubclass(resolver_class, FieldResolver) + and resolver_class is not FieldResolver +} diff --git a/rule_builder/rules.py b/rule_builder/rules.py index 77a89c96c2..07c0607c1f 100644 --- a/rule_builder/rules.py +++ b/rule_builder/rules.py @@ -7,6 +7,7 @@ from typing_extensions import TypeVar, dataclass_transform, override from BaseClasses import CollectionState from NetUtils import JSONMessagePart +from .field_resolvers import FieldResolver, FieldResolverRegister, resolve_field from .options import OptionFilter if TYPE_CHECKING: @@ -108,11 +109,14 @@ class Rule(Generic[TWorld]): def to_dict(self) -> dict[str, Any]: """Returns a JSON compatible dict representation of this rule""" - args = { - field.name: getattr(self, field.name, None) - for field in dataclasses.fields(self) - if field.name not in ("options", "filtered_resolution") - } + args = {} + for field in dataclasses.fields(self): + if field.name in ("options", "filtered_resolution"): + continue + value = getattr(self, field.name, None) + if isinstance(value, FieldResolver): + value = value.to_dict() + args[field.name] = value return { "rule": self.__class__.__qualname__, "options": [o.to_dict() for o in self.options], @@ -124,7 +128,19 @@ class Rule(Generic[TWorld]): def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: """Returns a new instance of this rule from a serialized dict representation""" options = OptionFilter.multiple_from_dict(data.get("options", ())) - return cls(**data.get("args", {}), options=options, filtered_resolution=data.get("filtered_resolution", False)) + args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game) + return cls(**args, options=options, filtered_resolution=data.get("filtered_resolution", False)) + + @classmethod + def _parse_field_resolvers(cls, data: Mapping[str, Any], game_name: str) -> dict[str, Any]: + result: dict[str, Any] = {} + for name, value in data.items(): + if isinstance(value, dict) and "resolver" in value: + resolver_cls = FieldResolverRegister.get_resolver_cls(game_name, value["resolver"]) # pyright: ignore[reportUnknownArgumentType] + result[name] = resolver_cls.from_dict(value) # pyright: ignore[reportUnknownArgumentType] + else: + result[name] = value + return result def __and__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]": """Combines two rules or a rule and an option filter into an And rule""" @@ -688,24 +704,24 @@ class Filtered(WrapperRule[TWorld], game="Archipelago"): class Has(Rule[TWorld], game="Archipelago"): """A rule that checks if the player has at least `count` of a given item""" - item_name: str + item_name: str | FieldResolver """The item to check for""" - count: int = 1 + count: int | FieldResolver = 1 """The count the player is required to have""" @override def _instantiate(self, world: TWorld) -> Rule.Resolved: return self.Resolved( - self.item_name, - self.count, + resolve_field(self.item_name, world, str), + count=resolve_field(self.count, world, int), player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) @override def __str__(self) -> str: - count = f", count={self.count}" if self.count > 1 else "" + count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else "" options = f", options={self.options}" if self.options else "" return f"{self.__class__.__name__}({self.item_name}{count}{options})" @@ -991,7 +1007,7 @@ class HasAny(Rule[TWorld], game="Archipelago"): class HasAllCounts(Rule[TWorld], game="Archipelago"): """A rule that checks if the player has all of the specified counts of the given items""" - item_counts: dict[str, int] + item_counts: Mapping[str, int | FieldResolver] """A mapping of item name to count to check for""" @override @@ -1002,12 +1018,30 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"): if len(self.item_counts) == 1: item = next(iter(self.item_counts)) return Has(item, self.item_counts[item]).resolve(world) + item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items()) return self.Resolved( - tuple(self.item_counts.items()), + item_counts, player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) + @override + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + output["args"]["item_counts"] = { + key: value.to_dict() if isinstance(value, FieldResolver) else value + for key, value in output["args"]["item_counts"].items() + } + return output + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = data.get("args", {}) + item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False)) + @override def __str__(self) -> str: items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()]) @@ -1096,7 +1130,7 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"): class HasAnyCount(Rule[TWorld], game="Archipelago"): """A rule that checks if the player has any of the specified counts of the given items""" - item_counts: dict[str, int] + item_counts: Mapping[str, int | FieldResolver] """A mapping of item name to count to check for""" @override @@ -1107,12 +1141,30 @@ class HasAnyCount(Rule[TWorld], game="Archipelago"): if len(self.item_counts) == 1: item = next(iter(self.item_counts)) return Has(item, self.item_counts[item]).resolve(world) + item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items()) return self.Resolved( - tuple(self.item_counts.items()), + item_counts, player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) + @override + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + output["args"]["item_counts"] = { + key: value.to_dict() if isinstance(value, FieldResolver) else value + for key, value in output["args"]["item_counts"].items() + } + return output + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = data.get("args", {}) + item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False)) + @override def __str__(self) -> str: items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()]) @@ -1204,13 +1256,13 @@ class HasFromList(Rule[TWorld], game="Archipelago"): item_names: tuple[str, ...] """A tuple of item names to check for""" - count: int = 1 + count: int | FieldResolver = 1 """The number of items the player needs to have""" def __init__( self, *item_names: str, - count: int = 1, + count: int | FieldResolver = 1, options: Iterable[OptionFilter] = (), filtered_resolution: bool = False, ) -> None: @@ -1227,7 +1279,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"): return Has(self.item_names[0], self.count).resolve(world) return self.Resolved( self.item_names, - self.count, + count=resolve_field(self.count, world, int), player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) @@ -1235,7 +1287,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"): @override @classmethod def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: - args = {**data.get("args", {})} + args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game) item_names = args.pop("item_names", ()) options = OptionFilter.multiple_from_dict(data.get("options", ())) return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False)) @@ -1338,13 +1390,13 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"): item_names: tuple[str, ...] """A tuple of item names to check for""" - count: int = 1 + count: int | FieldResolver = 1 """The number of items the player needs to have""" def __init__( self, *item_names: str, - count: int = 1, + count: int | FieldResolver = 1, options: Iterable[OptionFilter] = (), filtered_resolution: bool = False, ) -> None: @@ -1354,14 +1406,15 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"): @override def _instantiate(self, world: TWorld) -> Rule.Resolved: - if len(self.item_names) == 0 or len(self.item_names) < self.count: + count = resolve_field(self.count, world, int) + if len(self.item_names) == 0 or len(self.item_names) < count: # match state.has_from_list_unique return False_().resolve(world) if len(self.item_names) == 1: return Has(self.item_names[0]).resolve(world) return self.Resolved( self.item_names, - self.count, + count, player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) @@ -1369,7 +1422,7 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"): @override @classmethod def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: - args = {**data.get("args", {})} + args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game) item_names = args.pop("item_names", ()) options = OptionFilter.multiple_from_dict(data.get("options", ())) return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False)) @@ -1468,7 +1521,7 @@ class HasGroup(Rule[TWorld], game="Archipelago"): item_name_group: str """The name of the item group containing the items""" - count: int = 1 + count: int | FieldResolver = 1 """The number of items the player needs to have""" @override @@ -1477,14 +1530,14 @@ class HasGroup(Rule[TWorld], game="Archipelago"): return self.Resolved( self.item_name_group, item_names, - self.count, + count=resolve_field(self.count, world, int), player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) @override def __str__(self) -> str: - count = f", count={self.count}" if self.count > 1 else "" + count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else "" options = f", options={self.options}" if self.options else "" return f"{self.__class__.__name__}({self.item_name_group}{count}{options})" @@ -1542,7 +1595,7 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"): item_name_group: str """The name of the item group containing the items""" - count: int = 1 + count: int | FieldResolver = 1 """The number of items the player needs to have""" @override @@ -1551,14 +1604,14 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"): return self.Resolved( self.item_name_group, item_names, - self.count, + count=resolve_field(self.count, world, int), player=world.player, caching_enabled=getattr(world, "rule_caching_enabled", False), ) @override def __str__(self) -> str: - count = f", count={self.count}" if self.count > 1 else "" + count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else "" options = f", options={self.options}" if self.options else "" return f"{self.__class__.__name__}({self.item_name_group}{count}{options})" diff --git a/test/general/test_rule_builder.py b/test/general/test_rule_builder.py index 81003dcd87..85e239175d 100644 --- a/test/general/test_rule_builder.py +++ b/test/general/test_rule_builder.py @@ -6,8 +6,9 @@ from typing_extensions import override from BaseClasses import CollectionState, Item, ItemClassification, Location, MultiWorld, Region from NetUtils import JSONMessagePart -from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Toggle +from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Range, Toggle from rule_builder.cached_world import CachedRuleBuilderWorld +from rule_builder.field_resolvers import FieldResolver, FromOption, FromWorldAttr, resolve_field from rule_builder.options import Operator, OptionFilter from rule_builder.rules import ( And, @@ -59,12 +60,20 @@ class SetOption(OptionSet): valid_keys: ClassVar[set[str]] = {"one", "two", "three"} # pyright: ignore[reportIncompatibleVariableOverride] +class RangeOption(Range): + auto_display_name = True + range_start = 1 + range_end = 10 + default = 5 + + @dataclass class RuleBuilderOptions(PerGameCommonOptions): toggle_option: ToggleOption choice_option: ChoiceOption text_option: FreeTextOption set_option: SetOption + range_option: RangeOption GAME_NAME = "Rule Builder Test Game" @@ -659,14 +668,15 @@ class TestRules(RuleBuilderTestCase): self.assertFalse(resolved_rule(self.state)) def test_has_any_count(self) -> None: - item_counts = {"Item 1": 1, "Item 2": 2} + item_counts: dict[str, int | FieldResolver] = {"Item 1": 1, "Item 2": 2} rule = HasAnyCount(item_counts) resolved_rule = rule.resolve(self.world) self.world.register_rule_dependencies(resolved_rule) for item_name, count in item_counts.items(): item = self.world.create_item(item_name) - for _ in range(count): + num_items = resolve_field(count, self.world, int) + for _ in range(num_items): self.assertFalse(resolved_rule(self.state)) self.state.collect(item) self.assertTrue(resolved_rule(self.state)) @@ -763,7 +773,7 @@ class TestSerialization(RuleBuilderTestCase): rule: ClassVar[Rule[Any]] = And( Or( - Has("i1", count=4), + Has("i1", count=FromOption(RangeOption)), HasFromList("i2", "i3", "i4", count=2), HasAnyCount({"i5": 2, "i6": 3}), options=[OptionFilter(ToggleOption, 0)], @@ -771,7 +781,7 @@ class TestSerialization(RuleBuilderTestCase): Or( HasAll("i7", "i8"), HasAllCounts( - {"i9": 1, "i10": 5}, + {"i9": 1, "i10": FromWorldAttr("instance_data.i10_count")}, options=[OptionFilter(ToggleOption, 1, operator="ne")], filtered_resolution=True, ), @@ -811,7 +821,14 @@ class TestSerialization(RuleBuilderTestCase): "rule": "Has", "options": [], "filtered_resolution": False, - "args": {"item_name": "i1", "count": 4}, + "args": { + "item_name": "i1", + "count": { + "resolver": "FromOption", + "option": "test.general.test_rule_builder.RangeOption", + "field": "value", + }, + }, }, { "rule": "HasFromList", @@ -848,7 +865,12 @@ class TestSerialization(RuleBuilderTestCase): }, ], "filtered_resolution": True, - "args": {"item_counts": {"i9": 1, "i10": 5}}, + "args": { + "item_counts": { + "i9": 1, + "i10": {"resolver": "FromWorldAttr", "name": "instance_data.i10_count"}, + } + }, }, { "rule": "CanReachRegion", @@ -923,7 +945,7 @@ class TestSerialization(RuleBuilderTestCase): multiworld = setup_solo_multiworld(self.world_cls, steps=(), seed=0) world = multiworld.worlds[1] deserialized_rule = world.rule_from_dict(self.rule_dict) - self.assertEqual(deserialized_rule, self.rule, str(deserialized_rule)) + self.assertEqual(deserialized_rule, self.rule, f"\n{deserialized_rule}\n{self.rule}") class TestExplain(RuleBuilderTestCase): @@ -1342,3 +1364,32 @@ class TestExplain(RuleBuilderTestCase): "& False)", ) assert str(self.resolved_rule) == " ".join(expected) + + +@classvar_matrix( + rules=( + ( + Has("A", FromOption(RangeOption)), + Has.Resolved("A", count=5, player=1), + ), + ( + Has("B", FromWorldAttr("pre_calculated")), + Has.Resolved("B", count=3, player=1), + ), + ( + Has("C", FromWorldAttr("instance_data.key")), + Has.Resolved("C", count=7, player=1), + ), + ) +) +class TestFieldResolvers(RuleBuilderTestCase): + rules: ClassVar[tuple[Rule[Any], Rule.Resolved]] + + def test_simplify(self) -> None: + multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0) + world = multiworld.worlds[1] + world.pre_calculated = 3 # pyright: ignore[reportAttributeAccessIssue] + world.instance_data = {"key": 7} # pyright: ignore[reportAttributeAccessIssue] + rule, expected = self.rules + resolved_rule = rule.resolve(world) + self.assertEqual(resolved_rule, expected, f"\n{resolved_rule}\n{expected}")