Rule Builder: Add field resolvers (#5919)

This commit is contained in:
Ian Robinson
2026-03-30 12:19:10 -04:00
committed by GitHub
parent 58a6407040
commit c640d2fa24
5 changed files with 341 additions and 38 deletions

View File

@@ -0,0 +1,162 @@
import dataclasses
import importlib
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast, overload
from typing_extensions import override
from Options import Option
if TYPE_CHECKING:
from worlds.AutoWorld import World
class FieldResolverRegister:
"""A container class to contain world custom resolvers"""
custom_resolvers: ClassVar[dict[str, dict[str, type["FieldResolver"]]]] = {}
"""
A mapping of game name to mapping of resolver name to resolver class
to hold custom resolvers implemented by worlds
"""
@classmethod
def get_resolver_cls(cls, game_name: str, resolver_name: str) -> type["FieldResolver"]:
"""Returns the world-registered or default resolver with the given name"""
custom_resolver_classes = cls.custom_resolvers.get(game_name, {})
if resolver_name not in DEFAULT_RESOLVERS and resolver_name not in custom_resolver_classes:
raise ValueError(f"Resolver '{resolver_name}' for game '{game_name}' not found")
return custom_resolver_classes.get(resolver_name) or DEFAULT_RESOLVERS[resolver_name]
@dataclasses.dataclass(frozen=True)
class FieldResolver(ABC):
@abstractmethod
def resolve(self, world: "World") -> Any: ...
def to_dict(self) -> dict[str, Any]:
"""Returns a JSON compatible dict representation of this resolver"""
fields = {field.name: getattr(self, field.name, None) for field in dataclasses.fields(self)}
return {
"resolver": self.__class__.__name__,
**fields,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
"""Returns a new instance of this resolver from a serialized dict representation"""
assert data.get("resolver", None) == cls.__name__
return cls(**{k: v for k, v in data.items() if k != "resolver"})
@override
def __str__(self) -> str:
return self.__class__.__name__
@classmethod
def __init_subclass__(cls, /, game: str) -> None:
if game != "Archipelago":
custom_resolvers = FieldResolverRegister.custom_resolvers.setdefault(game, {})
if cls.__qualname__ in custom_resolvers:
raise TypeError(f"Resolver {cls.__qualname__} has already been registered for game {game}")
custom_resolvers[cls.__qualname__] = cls
elif cls.__module__ != "rule_builder.field_resolvers":
raise TypeError("You cannot define custom resolvers for the base Archipelago world")
@dataclasses.dataclass(frozen=True)
class FromOption(FieldResolver, game="Archipelago"):
option: type[Option[Any]]
field: str = "value"
@override
def resolve(self, world: "World") -> Any:
option_name = next(
(name for name, cls in world.options.__class__.type_hints.items() if cls is self.option),
None,
)
if option_name is None:
raise ValueError(
f"Cannot find option {self.option.__name__} in options class {world.options.__class__.__name__}"
)
opt = cast(Option[Any] | None, getattr(world.options, option_name, None))
if opt is None:
raise ValueError(f"Invalid option: {option_name}")
return getattr(opt, self.field)
@override
def to_dict(self) -> dict[str, Any]:
return {
"resolver": "FromOption",
"option": f"{self.option.__module__}.{self.option.__name__}",
"field": self.field,
}
@override
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
if "option" not in data:
raise ValueError("Missing required option")
option_path = data["option"]
try:
option_mod_name, option_cls_name = option_path.rsplit(".", 1)
option_module = importlib.import_module(option_mod_name)
option = getattr(option_module, option_cls_name, None)
except (ValueError, ImportError) as e:
raise ValueError(f"Cannot parse option '{option_path}'") from e
if option is None or not issubclass(option, Option):
raise ValueError(f"Invalid option '{option_path}' returns type '{option}' instead of Option subclass")
return cls(cast(type[Option[Any]], option), data.get("field", "value"))
@override
def __str__(self) -> str:
field = f".{self.field}" if self.field != "value" else ""
return f"FromOption({self.option.__name__}{field})"
@dataclasses.dataclass(frozen=True)
class FromWorldAttr(FieldResolver, game="Archipelago"):
name: str
@override
def resolve(self, world: "World") -> Any:
obj: Any = world
for field in self.name.split("."):
if obj is None:
return None
if isinstance(obj, Mapping):
obj = obj.get(field, None) # pyright: ignore[reportUnknownMemberType]
else:
obj = getattr(obj, field, None)
return obj
@override
def __str__(self) -> str:
return f"FromWorldAttr({self.name})"
T = TypeVar("T")
@overload
def resolve_field(field: Any, world: "World", expected_type: type[T]) -> T: ...
@overload
def resolve_field(field: Any, world: "World", expected_type: None = None) -> Any: ...
def resolve_field(field: Any, world: "World", expected_type: type[T] | None = None) -> T | Any:
if isinstance(field, FieldResolver):
field = field.resolve(world)
if expected_type:
assert isinstance(field, expected_type), f"Expected type {expected_type} but got {type(field)}"
return field
DEFAULT_RESOLVERS = {
resolver_name: resolver_class
for resolver_name, resolver_class in locals().items()
if isinstance(resolver_class, type)
and issubclass(resolver_class, FieldResolver)
and resolver_class is not FieldResolver
}

View File

@@ -7,6 +7,7 @@ from typing_extensions import TypeVar, dataclass_transform, override
from BaseClasses import CollectionState
from NetUtils import JSONMessagePart
from .field_resolvers import FieldResolver, FieldResolverRegister, resolve_field
from .options import OptionFilter
if TYPE_CHECKING:
@@ -108,11 +109,14 @@ class Rule(Generic[TWorld]):
def to_dict(self) -> dict[str, Any]:
"""Returns a JSON compatible dict representation of this rule"""
args = {
field.name: getattr(self, field.name, None)
for field in dataclasses.fields(self)
if field.name not in ("options", "filtered_resolution")
}
args = {}
for field in dataclasses.fields(self):
if field.name in ("options", "filtered_resolution"):
continue
value = getattr(self, field.name, None)
if isinstance(value, FieldResolver):
value = value.to_dict()
args[field.name] = value
return {
"rule": self.__class__.__qualname__,
"options": [o.to_dict() for o in self.options],
@@ -124,7 +128,19 @@ class Rule(Generic[TWorld]):
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
"""Returns a new instance of this rule from a serialized dict representation"""
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(**data.get("args", {}), options=options, filtered_resolution=data.get("filtered_resolution", False))
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
return cls(**args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@classmethod
def _parse_field_resolvers(cls, data: Mapping[str, Any], game_name: str) -> dict[str, Any]:
result: dict[str, Any] = {}
for name, value in data.items():
if isinstance(value, dict) and "resolver" in value:
resolver_cls = FieldResolverRegister.get_resolver_cls(game_name, value["resolver"]) # pyright: ignore[reportUnknownArgumentType]
result[name] = resolver_cls.from_dict(value) # pyright: ignore[reportUnknownArgumentType]
else:
result[name] = value
return result
def __and__(self, other: "Rule[Any] | Iterable[OptionFilter] | OptionFilter") -> "Rule[TWorld]":
"""Combines two rules or a rule and an option filter into an And rule"""
@@ -688,24 +704,24 @@ class Filtered(WrapperRule[TWorld], game="Archipelago"):
class Has(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has at least `count` of a given item"""
item_name: str
item_name: str | FieldResolver
"""The item to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The count the player is required to have"""
@override
def _instantiate(self, world: TWorld) -> Rule.Resolved:
return self.Resolved(
self.item_name,
self.count,
resolve_field(self.item_name, world, str),
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name}{count}{options})"
@@ -991,7 +1007,7 @@ class HasAny(Rule[TWorld], game="Archipelago"):
class HasAllCounts(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has all of the specified counts of the given items"""
item_counts: dict[str, int]
item_counts: Mapping[str, int | FieldResolver]
"""A mapping of item name to count to check for"""
@override
@@ -1002,12 +1018,30 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
if len(self.item_counts) == 1:
item = next(iter(self.item_counts))
return Has(item, self.item_counts[item]).resolve(world)
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
return self.Resolved(
tuple(self.item_counts.items()),
item_counts,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["args"]["item_counts"] = {
key: value.to_dict() if isinstance(value, FieldResolver) else value
for key, value in output["args"]["item_counts"].items()
}
return output
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = data.get("args", {})
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
@override
def __str__(self) -> str:
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
@@ -1096,7 +1130,7 @@ class HasAllCounts(Rule[TWorld], game="Archipelago"):
class HasAnyCount(Rule[TWorld], game="Archipelago"):
"""A rule that checks if the player has any of the specified counts of the given items"""
item_counts: dict[str, int]
item_counts: Mapping[str, int | FieldResolver]
"""A mapping of item name to count to check for"""
@override
@@ -1107,12 +1141,30 @@ class HasAnyCount(Rule[TWorld], game="Archipelago"):
if len(self.item_counts) == 1:
item = next(iter(self.item_counts))
return Has(item, self.item_counts[item]).resolve(world)
item_counts = tuple((name, resolve_field(count, world, int)) for name, count in self.item_counts.items())
return self.Resolved(
tuple(self.item_counts.items()),
item_counts,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["args"]["item_counts"] = {
key: value.to_dict() if isinstance(value, FieldResolver) else value
for key, value in output["args"]["item_counts"].items()
}
return output
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = data.get("args", {})
item_counts = cls._parse_field_resolvers(args.get("item_counts", {}), world_cls.game)
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(item_counts, options=options, filtered_resolution=data.get("filtered_resolution", False))
@override
def __str__(self) -> str:
items = ", ".join([f"{item} x{count}" for item, count in self.item_counts.items()])
@@ -1204,13 +1256,13 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
item_names: tuple[str, ...]
"""A tuple of item names to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
def __init__(
self,
*item_names: str,
count: int = 1,
count: int | FieldResolver = 1,
options: Iterable[OptionFilter] = (),
filtered_resolution: bool = False,
) -> None:
@@ -1227,7 +1279,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
return Has(self.item_names[0], self.count).resolve(world)
return self.Resolved(
self.item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@@ -1235,7 +1287,7 @@ class HasFromList(Rule[TWorld], game="Archipelago"):
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = {**data.get("args", {})}
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
item_names = args.pop("item_names", ())
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@@ -1338,13 +1390,13 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
item_names: tuple[str, ...]
"""A tuple of item names to check for"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
def __init__(
self,
*item_names: str,
count: int = 1,
count: int | FieldResolver = 1,
options: Iterable[OptionFilter] = (),
filtered_resolution: bool = False,
) -> None:
@@ -1354,14 +1406,15 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
@override
def _instantiate(self, world: TWorld) -> Rule.Resolved:
if len(self.item_names) == 0 or len(self.item_names) < self.count:
count = resolve_field(self.count, world, int)
if len(self.item_names) == 0 or len(self.item_names) < count:
# match state.has_from_list_unique
return False_().resolve(world)
if len(self.item_names) == 1:
return Has(self.item_names[0]).resolve(world)
return self.Resolved(
self.item_names,
self.count,
count,
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@@ -1369,7 +1422,7 @@ class HasFromListUnique(Rule[TWorld], game="Archipelago"):
@override
@classmethod
def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self:
args = {**data.get("args", {})}
args = cls._parse_field_resolvers(data.get("args", {}), world_cls.game)
item_names = args.pop("item_names", ())
options = OptionFilter.multiple_from_dict(data.get("options", ()))
return cls(*item_names, **args, options=options, filtered_resolution=data.get("filtered_resolution", False))
@@ -1468,7 +1521,7 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
item_name_group: str
"""The name of the item group containing the items"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
@override
@@ -1477,14 +1530,14 @@ class HasGroup(Rule[TWorld], game="Archipelago"):
return self.Resolved(
self.item_name_group,
item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"
@@ -1542,7 +1595,7 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
item_name_group: str
"""The name of the item group containing the items"""
count: int = 1
count: int | FieldResolver = 1
"""The number of items the player needs to have"""
@override
@@ -1551,14 +1604,14 @@ class HasGroupUnique(Rule[TWorld], game="Archipelago"):
return self.Resolved(
self.item_name_group,
item_names,
self.count,
count=resolve_field(self.count, world, int),
player=world.player,
caching_enabled=getattr(world, "rule_caching_enabled", False),
)
@override
def __str__(self) -> str:
count = f", count={self.count}" if self.count > 1 else ""
count = f", count={self.count}" if isinstance(self.count, FieldResolver) or self.count > 1 else ""
options = f", options={self.options}" if self.options else ""
return f"{self.__class__.__name__}({self.item_name_group}{count}{options})"