mirror of
https://github.com/ArchipelagoMW/Archipelago.git
synced 2026-03-30 12:43:24 -07:00
Rule Builder: Add field resolvers (#5919)
This commit is contained in:
1
.github/pyright-config.json
vendored
1
.github/pyright-config.json
vendored
@@ -3,6 +3,7 @@
|
||||
"../BizHawkClient.py",
|
||||
"../Patch.py",
|
||||
"../rule_builder/cached_world.py",
|
||||
"../rule_builder/field_resolvers.py",
|
||||
"../rule_builder/options.py",
|
||||
"../rule_builder/rules.py",
|
||||
"../test/param.py",
|
||||
|
||||
@@ -129,6 +129,42 @@ common_rule_only_on_easy = common_rule & easy_filter
|
||||
common_rule_skipped_on_easy = common_rule | easy_filter
|
||||
```
|
||||
|
||||
### Field resolvers
|
||||
|
||||
When creating rules you may sometimes need to set a field to a value that depends on the world instance. You can use a `FieldResolver` to define how to populate that field when the rule is being resolved.
|
||||
|
||||
There are two build-in field resolvers:
|
||||
|
||||
- `FromOption`: Resolves to the value of the given option
|
||||
- `FromWorldAttr`: Resolves to the value of the given world instance attribute, can specify a dotted path `a.b.c` to get a nested attribute or dict item
|
||||
|
||||
```python
|
||||
world.options.mcguffin_count = 5
|
||||
world.precalculated_value = 99
|
||||
rule = (
|
||||
Has("A", count=FromOption(McguffinCount))
|
||||
| HasGroup("Important items", count=FromWorldAttr("precalculated_value"))
|
||||
)
|
||||
# Results in Has("A", count=5) | HasGroup("Important items", count=99)
|
||||
```
|
||||
|
||||
You can define your own resolvers by creating a class that inherits from `FieldResolver`, provides your game name, and implements a `resolve` function:
|
||||
|
||||
```python
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FromCustomResolution(FieldResolver, game="MyGame"):
|
||||
modifier: str
|
||||
|
||||
@override
|
||||
def resolve(self, world: "World") -> Any:
|
||||
return some_math_calculation(world, self.modifier)
|
||||
|
||||
|
||||
rule = Has("Combat Level", count=FromCustomResolution("combat"))
|
||||
```
|
||||
|
||||
If you want to support rule serialization and your resolver contains non-serializable properties you may need to override `to_dict` or `from_dict`.
|
||||
|
||||
## Enabling caching
|
||||
|
||||
The rule builder provides a `CachedRuleBuilderWorld` base class for your `World` class that enables caching on your rules.
|
||||
|
||||
162
rule_builder/field_resolvers.py
Normal file
162
rule_builder/field_resolvers.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import dataclasses
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast, overload
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from Options import Option
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from worlds.AutoWorld import World
|
||||
|
||||
|
||||
class FieldResolverRegister:
|
||||
"""A container class to contain world custom resolvers"""
|
||||
|
||||
custom_resolvers: ClassVar[dict[str, dict[str, type["FieldResolver"]]]] = {}
|
||||
"""
|
||||
A mapping of game name to mapping of resolver name to resolver class
|
||||
to hold custom resolvers implemented by worlds
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_resolver_cls(cls, game_name: str, resolver_name: str) -> type["FieldResolver"]:
|
||||
"""Returns the world-registered or default resolver with the given name"""
|
||||
custom_resolver_classes = cls.custom_resolvers.get(game_name, {})
|
||||
if resolver_name not in DEFAULT_RESOLVERS and resolver_name not in custom_resolver_classes:
|
||||
raise ValueError(f"Resolver '{resolver_name}' for game '{game_name}' not found")
|
||||
return custom_resolver_classes.get(resolver_name) or DEFAULT_RESOLVERS[resolver_name]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FieldResolver(ABC):
|
||||
@abstractmethod
|
||||
def resolve(self, world: "World") -> Any: ...
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Returns a JSON compatible dict representation of this resolver"""
|
||||
fields = {field.name: getattr(self, field.name, None) for field in dataclasses.fields(self)}
|
||||
return {
|
||||
"resolver": self.__class__.__name__,
|
||||
**fields,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
"""Returns a new instance of this resolver from a serialized dict representation"""
|
||||
assert data.get("resolver", None) == cls.__name__
|
||||
return cls(**{k: v for k, v in data.items() if k != "resolver"})
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, /, game: str) -> None:
|
||||
if game != "Archipelago":
|
||||
custom_resolvers = FieldResolverRegister.custom_resolvers.setdefault(game, {})
|
||||
if cls.__qualname__ in custom_resolvers:
|
||||
raise TypeError(f"Resolver {cls.__qualname__} has already been registered for game {game}")
|
||||
custom_resolvers[cls.__qualname__] = cls
|
||||
elif cls.__module__ != "rule_builder.field_resolvers":
|
||||
raise TypeError("You cannot define custom resolvers for the base Archipelago world")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FromOption(FieldResolver, game="Archipelago"):
|
||||
option: type[Option[Any]]
|
||||
field: str = "value"
|
||||
|
||||
@override
|
||||
def resolve(self, world: "World") -> Any:
|
||||
option_name = next(
|
||||
(name for name, cls in world.options.__class__.type_hints.items() if cls is self.option),
|
||||
None,
|
||||
)
|
||||
|
||||
if option_name is None:
|
||||
raise ValueError(
|
||||
f"Cannot find option {self.option.__name__} in options class {world.options.__class__.__name__}"
|
||||
)
|
||||
opt = cast(Option[Any] | None, getattr(world.options, option_name, None))
|
||||
if opt is None:
|
||||
raise ValueError(f"Invalid option: {option_name}")
|
||||
return getattr(opt, self.field)
|
||||
|
||||
@override
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"resolver": "FromOption",
|
||||
"option": f"{self.option.__module__}.{self.option.__name__}",
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
if "option" not in data:
|
||||
raise ValueError("Missing required option")
|
||||
|
||||
option_path = data["option"]
|
||||
try:
|
||||
option_mod_name, option_cls_name = option_path.rsplit(".", 1)
|
||||
option_module = importlib.import_module(option_mod_name)
|
||||
option = getattr(option_module, option_cls_name, None)
|
||||
except (ValueError, ImportError) as e:
|
||||
raise ValueError(f"Cannot parse option '{option_path}'") from e
|
||||
if option is None or not issubclass(option, Option):
|
||||
raise ValueError(f"Invalid option '{option_path}' returns type '{option}' instead of Option subclass")
|
||||
|
||||
return cls(cast(type[Option[Any]], option), data.get("field", "value"))
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
field = f".{self.field}" if self.field != "value" else ""
|
||||
return f"FromOption({self.option.__name__}{field})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FromWorldAttr(FieldResolver, game="Archipelago"):
|
||||
name: str
|
||||
|
||||
@override
|
||||
def resolve(self, world: "World") -> Any:
|
||||
obj: Any = world
|
||||
for field in self.name.split("."):
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, Mapping):
|
||||
obj = obj.get(field, None) # pyright: ignore[reportUnknownMemberType]
|
||||
else:
|
||||
obj = getattr(obj, field, None)
|
||||
return obj
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"FromWorldAttr({self.name})"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def resolve_field(field: Any, world: "World", expected_type: type[T]) -> T: ...
|
||||
@overload
|
||||
def resolve_field(field: Any, world: "World", expected_type: None = None) -> Any: ...
|
||||
def resolve_field(field: Any, world: "World", expected_type: type[T] | None = None) -> T | Any:
|
||||
if isinstance(field, FieldResolver):
|
||||
field = field.resolve(world)
|
||||
if expected_type:
|
||||
assert isinstance(field, expected_type), f"Expected type {expected_type} but got {type(field)}"
|
||||
return field
|
||||
|
||||
|
||||
DEFAULT_RESOLVERS = {
|
||||
resolver_name: resolver_class
|
||||
for resolver_name, resolver_class in locals().items()
|
||||
if isinstance(resolver_class, type)
|
||||
and issubclass(resolver_class, FieldResolver)
|
||||
and resolver_class is not FieldResolver
|
||||
}
|
||||
@@ -7,6 +7,7 @@ from typing_extensions import TypeVar, dataclass_transform, override
|
||||
from BaseClasses import CollectionState
|
||||
from NetUtils import JSONMessagePart
|
||||
|
||||
from .field_resolvers import FieldResolver, FieldResolverRegister, resolve_field
|
||||
from .options import OptionFilter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -108,11 +109,14 @@ class Rule(Generic[TWorld]):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Returns a JSON compatible dict representation of this rule"""
|
||||
args = {
|
||||
field.name: getattr(self, field.name, None)
|
||||
for field in dataclasses.fields(self)
|
||||
if field.name not in ("options", "filtered_resolution")
|
||||
}
|
||||
args = {}
|
||||
for field in dataclasses.fields(self):
|
||||
if field.name in ("options", "filtered_resolution"):
|
||||
continue
|
||||
value = getattr(self, field.name, None)
|
||||
if isinstance(value, FieldResolver):
|
||||
value = value.to_dict()
|
||||
args[field.name] = value
|
||||
return {
|
||||
"rule": self.__class__.__qualname__,
|
||||
"options": [o.to_dict() for o in self.options],
|
||||
@@ -124,7 +128,19 @@ class Rule(Generic[TWorld]):
|
||||
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
|
||||
"""Returns a new instance of this rule from a serialized dict representation"""
|
||||
options = OptionFilter.multiple_from_dict(data.get("options", ()))
|
||||
return cls(**data.get("args", {}), options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
|
||||
return cls(**args, options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
|
||||
@classmethod
|
||||
def _parse_field_resolvers(cls, data: Mapping[str, Any], game_name: str) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
for name, value in data.items():
|
||||
if isinstance(value, dict) and "resolver" in value:
|
||||
resolver_cls = FieldResolverRegister.get_resolver_cls(game_name, value["resolver"]) # pyright: ignore[reportUnknownArgumentType]
|
||||
result[name] = resolver_cls.from_dict(value) # pyright: ignore[reportUnknownArgumentType]
|
||||
else:
|
||||
result[name] = value
|
||||
return result
|
||||
|
||||
def __and__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]":
|
||||
"""Combines two rules or a rule and an option filter into an And rule"""
|
||||
@@ -688,24 +704,24 @@ class Filtered(WrapperRule[TWorld], game="Archipelago"):
|
||||
class Has(Rule[TWorld], game="Archipelago"):
|
||||
"""A rule that checks if the player has at least `count` of a given item"""
|
||||
|
||||
item_name: str
|
||||
item_name: str | FieldResolver
|
||||
"""The item to check for"""
|
||||
|
||||
count: int = 1
|
||||
count: int | FieldResolver = 1
|
||||
"""The count the player is required to have"""
|
||||
|
||||
@override
|
||||
def _instantiate(self, world: TWorld) -> Rule.Resolved:
|
||||
return self.Resolved(
|
||||
self.item_name,
|
||||
self.count,
|
||||
resolve_field(self.item_name, world, str),
|
||||
count=resolve_field(self.count, world, int),
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
count = f", count={self.count}" if self.count > 1 else ""
|
||||
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
|
||||
options = f", options={self.options}" if self.options else ""
|
||||
return f"{self.__class__.__name__}({self.item_name}{count}{options})"
|
||||
|
||||
@@ -991,7 +1007,7 @@ class HasAny(Rule[TWorld], game="Archipelago"):
|
||||
class HasAllCounts(Rule[TWorld], game="Archipelago"):
|
||||
"""A rule that checks if the player has all of the specified counts of the given items"""
|
||||
|
||||
item_counts: dict[str, int]
|
||||
item_counts: Mapping[str, int | FieldResolver]
|
||||
"""A mapping of item name to count to check for"""
|
||||
|
||||
@override
|
||||
@@ -1002,12 +1018,30 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
|
||||
if len(self.item_counts) == 1:
|
||||
item = next(iter(self.item_counts))
|
||||
return Has(item, self.item_counts[item]).resolve(world)
|
||||
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
|
||||
return self.Resolved(
|
||||
tuple(self.item_counts.items()),
|
||||
item_counts,
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
|
||||
@override
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
output = super().to_dict()
|
||||
output["args"]["item_counts"] = {
|
||||
key: value.to_dict() if isinstance(value, FieldResolver) else value
|
||||
for key, value in output["args"]["item_counts"].items()
|
||||
}
|
||||
return output
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
|
||||
args = data.get("args", {})
|
||||
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
|
||||
options = OptionFilter.multiple_from_dict(data.get("options", ()))
|
||||
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
|
||||
@@ -1096,7 +1130,7 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
|
||||
class HasAnyCount(Rule[TWorld], game="Archipelago"):
|
||||
"""A rule that checks if the player has any of the specified counts of the given items"""
|
||||
|
||||
item_counts: dict[str, int]
|
||||
item_counts: Mapping[str, int | FieldResolver]
|
||||
"""A mapping of item name to count to check for"""
|
||||
|
||||
@override
|
||||
@@ -1107,12 +1141,30 @@ class HasAnyCount(Rule[TWorld], game="Archipelago"):
|
||||
if len(self.item_counts) == 1:
|
||||
item = next(iter(self.item_counts))
|
||||
return Has(item, self.item_counts[item]).resolve(world)
|
||||
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
|
||||
return self.Resolved(
|
||||
tuple(self.item_counts.items()),
|
||||
item_counts,
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
|
||||
@override
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
output = super().to_dict()
|
||||
output["args"]["item_counts"] = {
|
||||
key: value.to_dict() if isinstance(value, FieldResolver) else value
|
||||
for key, value in output["args"]["item_counts"].items()
|
||||
}
|
||||
return output
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
|
||||
args = data.get("args", {})
|
||||
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
|
||||
options = OptionFilter.multiple_from_dict(data.get("options", ()))
|
||||
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
|
||||
@@ -1204,13 +1256,13 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
|
||||
item_names: tuple[str, ...]
|
||||
"""A tuple of item names to check for"""
|
||||
|
||||
count: int = 1
|
||||
count: int | FieldResolver = 1
|
||||
"""The number of items the player needs to have"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*item_names: str,
|
||||
count: int = 1,
|
||||
count: int | FieldResolver = 1,
|
||||
options: Iterable[OptionFilter] = (),
|
||||
filtered_resolution: bool = False,
|
||||
) -> None:
|
||||
@@ -1227,7 +1279,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
|
||||
return Has(self.item_names[0], self.count).resolve(world)
|
||||
return self.Resolved(
|
||||
self.item_names,
|
||||
self.count,
|
||||
count=resolve_field(self.count, world, int),
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
@@ -1235,7 +1287,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
|
||||
@override
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
|
||||
args = {**data.get("args", {})}
|
||||
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
|
||||
item_names = args.pop("item_names", ())
|
||||
options = OptionFilter.multiple_from_dict(data.get("options", ()))
|
||||
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
@@ -1338,13 +1390,13 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
|
||||
item_names: tuple[str, ...]
|
||||
"""A tuple of item names to check for"""
|
||||
|
||||
count: int = 1
|
||||
count: int | FieldResolver = 1
|
||||
"""The number of items the player needs to have"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*item_names: str,
|
||||
count: int = 1,
|
||||
count: int | FieldResolver = 1,
|
||||
options: Iterable[OptionFilter] = (),
|
||||
filtered_resolution: bool = False,
|
||||
) -> None:
|
||||
@@ -1354,14 +1406,15 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
|
||||
|
||||
@override
|
||||
def _instantiate(self, world: TWorld) -> Rule.Resolved:
|
||||
if len(self.item_names) == 0 or len(self.item_names) < self.count:
|
||||
count = resolve_field(self.count, world, int)
|
||||
if len(self.item_names) == 0 or len(self.item_names) < count:
|
||||
# match state.has_from_list_unique
|
||||
return False_().resolve(world)
|
||||
if len(self.item_names) == 1:
|
||||
return Has(self.item_names[0]).resolve(world)
|
||||
return self.Resolved(
|
||||
self.item_names,
|
||||
self.count,
|
||||
count,
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
@@ -1369,7 +1422,7 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
|
||||
@override
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
|
||||
args = {**data.get("args", {})}
|
||||
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
|
||||
item_names = args.pop("item_names", ())
|
||||
options = OptionFilter.multiple_from_dict(data.get("options", ()))
|
||||
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
|
||||
@@ -1468,7 +1521,7 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
|
||||
item_name_group: str
|
||||
"""The name of the item group containing the items"""
|
||||
|
||||
count: int = 1
|
||||
count: int | FieldResolver = 1
|
||||
"""The number of items the player needs to have"""
|
||||
|
||||
@override
|
||||
@@ -1477,14 +1530,14 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
|
||||
return self.Resolved(
|
||||
self.item_name_group,
|
||||
item_names,
|
||||
self.count,
|
||||
count=resolve_field(self.count, world, int),
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
count = f", count={self.count}" if self.count > 1 else ""
|
||||
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
|
||||
options = f", options={self.options}" if self.options else ""
|
||||
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"
|
||||
|
||||
@@ -1542,7 +1595,7 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
|
||||
item_name_group: str
|
||||
"""The name of the item group containing the items"""
|
||||
|
||||
count: int = 1
|
||||
count: int | FieldResolver = 1
|
||||
"""The number of items the player needs to have"""
|
||||
|
||||
@override
|
||||
@@ -1551,14 +1604,14 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
|
||||
return self.Resolved(
|
||||
self.item_name_group,
|
||||
item_names,
|
||||
self.count,
|
||||
count=resolve_field(self.count, world, int),
|
||||
player=world.player,
|
||||
caching_enabled=getattr(world, "rule_caching_enabled", False),
|
||||
)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
count = f", count={self.count}" if self.count > 1 else ""
|
||||
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
|
||||
options = f", options={self.options}" if self.options else ""
|
||||
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"
|
||||
|
||||
|
||||
@@ -6,8 +6,9 @@ from typing_extensions import override
|
||||
|
||||
from BaseClasses import CollectionState, Item, ItemClassification, Location, MultiWorld, Region
|
||||
from NetUtils import JSONMessagePart
|
||||
from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Toggle
|
||||
from Options import Choice, FreeText, Option, OptionSet, PerGameCommonOptions, Range, Toggle
|
||||
from rule_builder.cached_world import CachedRuleBuilderWorld
|
||||
from rule_builder.field_resolvers import FieldResolver, FromOption, FromWorldAttr, resolve_field
|
||||
from rule_builder.options import Operator, OptionFilter
|
||||
from rule_builder.rules import (
|
||||
And,
|
||||
@@ -59,12 +60,20 @@ class SetOption(OptionSet):
|
||||
valid_keys: ClassVar[set[str]] = {"one", "two", "three"} # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class RangeOption(Range):
|
||||
auto_display_name = True
|
||||
range_start = 1
|
||||
range_end = 10
|
||||
default = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleBuilderOptions(PerGameCommonOptions):
|
||||
toggle_option: ToggleOption
|
||||
choice_option: ChoiceOption
|
||||
text_option: FreeTextOption
|
||||
set_option: SetOption
|
||||
range_option: RangeOption
|
||||
|
||||
|
||||
GAME_NAME = "Rule Builder Test Game"
|
||||
@@ -659,14 +668,15 @@ class TestRules(RuleBuilderTestCase):
|
||||
self.assertFalse(resolved_rule(self.state))
|
||||
|
||||
def test_has_any_count(self) -> None:
|
||||
item_counts = {"Item 1": 1, "Item 2": 2}
|
||||
item_counts: dict[str, int | FieldResolver] = {"Item 1": 1, "Item 2": 2}
|
||||
rule = HasAnyCount(item_counts)
|
||||
resolved_rule = rule.resolve(self.world)
|
||||
self.world.register_rule_dependencies(resolved_rule)
|
||||
|
||||
for item_name, count in item_counts.items():
|
||||
item = self.world.create_item(item_name)
|
||||
for _ in range(count):
|
||||
num_items = resolve_field(count, self.world, int)
|
||||
for _ in range(num_items):
|
||||
self.assertFalse(resolved_rule(self.state))
|
||||
self.state.collect(item)
|
||||
self.assertTrue(resolved_rule(self.state))
|
||||
@@ -763,7 +773,7 @@ class TestSerialization(RuleBuilderTestCase):
|
||||
|
||||
rule: ClassVar[Rule[Any]] = And(
|
||||
Or(
|
||||
Has("i1", count=4),
|
||||
Has("i1", count=FromOption(RangeOption)),
|
||||
HasFromList("i2", "i3", "i4", count=2),
|
||||
HasAnyCount({"i5": 2, "i6": 3}),
|
||||
options=[OptionFilter(ToggleOption, 0)],
|
||||
@@ -771,7 +781,7 @@ class TestSerialization(RuleBuilderTestCase):
|
||||
Or(
|
||||
HasAll("i7", "i8"),
|
||||
HasAllCounts(
|
||||
{"i9": 1, "i10": 5},
|
||||
{"i9": 1, "i10": FromWorldAttr("instance_data.i10_count")},
|
||||
options=[OptionFilter(ToggleOption, 1, operator="ne")],
|
||||
filtered_resolution=True,
|
||||
),
|
||||
@@ -811,7 +821,14 @@ class TestSerialization(RuleBuilderTestCase):
|
||||
"rule": "Has",
|
||||
"options": [],
|
||||
"filtered_resolution": False,
|
||||
"args": {"item_name": "i1", "count": 4},
|
||||
"args": {
|
||||
"item_name": "i1",
|
||||
"count": {
|
||||
"resolver": "FromOption",
|
||||
"option": "test.general.test_rule_builder.RangeOption",
|
||||
"field": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"rule": "HasFromList",
|
||||
@@ -848,7 +865,12 @@ class TestSerialization(RuleBuilderTestCase):
|
||||
},
|
||||
],
|
||||
"filtered_resolution": True,
|
||||
"args": {"item_counts": {"i9": 1, "i10": 5}},
|
||||
"args": {
|
||||
"item_counts": {
|
||||
"i9": 1,
|
||||
"i10": {"resolver": "FromWorldAttr", "name": "instance_data.i10_count"},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"rule": "CanReachRegion",
|
||||
@@ -923,7 +945,7 @@ class TestSerialization(RuleBuilderTestCase):
|
||||
multiworld = setup_solo_multiworld(self.world_cls, steps=(), seed=0)
|
||||
world = multiworld.worlds[1]
|
||||
deserialized_rule = world.rule_from_dict(self.rule_dict)
|
||||
self.assertEqual(deserialized_rule, self.rule, str(deserialized_rule))
|
||||
self.assertEqual(deserialized_rule, self.rule, f"\n{deserialized_rule}\n{self.rule}")
|
||||
|
||||
|
||||
class TestExplain(RuleBuilderTestCase):
|
||||
@@ -1342,3 +1364,32 @@ class TestExplain(RuleBuilderTestCase):
|
||||
"& False)",
|
||||
)
|
||||
assert str(self.resolved_rule) == " ".join(expected)
|
||||
|
||||
|
||||
@classvar_matrix(
|
||||
rules=(
|
||||
(
|
||||
Has("A", FromOption(RangeOption)),
|
||||
Has.Resolved("A", count=5, player=1),
|
||||
),
|
||||
(
|
||||
Has("B", FromWorldAttr("pre_calculated")),
|
||||
Has.Resolved("B", count=3, player=1),
|
||||
),
|
||||
(
|
||||
Has("C", FromWorldAttr("instance_data.key")),
|
||||
Has.Resolved("C", count=7, player=1),
|
||||
),
|
||||
)
|
||||
)
|
||||
class TestFieldResolvers(RuleBuilderTestCase):
|
||||
rules: ClassVar[tuple[Rule[Any], Rule.Resolved]]
|
||||
|
||||
def test_simplify(self) -> None:
|
||||
multiworld = setup_solo_multiworld(self.world_cls, steps=("generate_early",), seed=0)
|
||||
world = multiworld.worlds[1]
|
||||
world.pre_calculated = 3 # pyright: ignore[reportAttributeAccessIssue]
|
||||
world.instance_data = {"key": 7} # pyright: ignore[reportAttributeAccessIssue]
|
||||
rule, expected = self.rules
|
||||
resolved_rule = rule.resolve(world)
|
||||
self.assertEqual(resolved_rule, expected, f"\n{resolved_rule}\n{expected}")
|
||||
|
||||
Reference in New Issue
Block a user