From 4a28888a6610d24f7c7bfe6909826e418f8f6221 Mon Sep 17 00:00:00 2001 From: Ishigh1 Date: Sat, 9 May 2026 16:56:10 +0200 Subject: [PATCH] Rule Builder: Implement AtLeast (#6085) --------- Co-authored-by: Ian Robinson --- docs/rule builder.md | 1 + rule_builder/rules.py | 136 +++++++++++++++++++++++++- test/general/test_rule_builder.py | 157 ++++++++++++++++++++++++------ 3 files changed, 263 insertions(+), 31 deletions(-) diff --git a/docs/rule builder.md b/docs/rule builder.md index 829ab763d7..8768e6447f 100644 --- a/docs/rule builder.md +++ b/docs/rule builder.md @@ -41,6 +41,7 @@ The rule builder comes with a number of rules by default: - `False_`: Always returns false - `And`: Checks that all child rules are true (also provided by `&` operator) - `Or`: Checks that at least one child rule is true (also provided by `|` operator) +- `AtLeast`: Checks that at least some count of rules is true - `Has`: Checks that the player has the given item with the given count (default 1) - `HasAll`: Checks that the player has all given items - `HasAny`: Checks that the player has at least one of the given items diff --git a/rule_builder/rules.py b/rule_builder/rules.py index 47f91aff5e..d940eeb386 100644 --- a/rule_builder/rules.py +++ b/rule_builder/rules.py @@ -425,13 +425,142 @@ class NestedRule(Rule[TWorld], game="Archipelago"): return combined_deps +class AtLeast(NestedRule[TWorld], game="Archipelago"): + """A rule that returns true when at least N child rules evaluate as true""" + + count: int | FieldResolver + + def __init__( + self, + count: int | FieldResolver, + *children: Rule[TWorld], + options: Iterable[OptionFilter] = (), + filtered_resolution: bool = False, + ) -> None: + super().__init__(*children, options=options, filtered_resolution=filtered_resolution) + self.count = count + + @override + def _instantiate(self, world: TWorld) -> Rule.Resolved: + count = resolve_field(self.count, world, int) + if count == 0: + return True_().resolve(world) + + children_to_process = [c.resolve(world) for c in self.children] + return AtLeast.from_resolved(count, world, children_to_process) + + @classmethod + def from_resolved(cls, count: int, world: TWorld, children_to_process: list[Rule.Resolved]) -> Rule.Resolved: + clauses: list[Rule.Resolved] = [] + + while children_to_process: + child = children_to_process.pop(0) + if child.always_true: + if count == 1: + return child + count -= 1 + continue + if child.always_false: + # falses can be ignored + continue + + clauses.append(child) + + if len(clauses) < count: + return False_().resolve(world) + if count == 1: + # Switch to Or which has more optimized handling + return Or.from_resolved(world, clauses) + if count == len(clauses): + # Switch to And which has more optimized handling + return And.from_resolved(world, clauses) + return AtLeast.Resolved( + tuple(clauses), + count=count, + player=world.player, + caching_enabled=getattr(world, "rule_caching_enabled", False), + ) + + @override + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + count = self.count + output["count"] = count.to_dict() if isinstance(count, FieldResolver) else count + return output + + @override + @classmethod + def from_dict(cls, data: Mapping[str, Any], world_cls: "type[World]") -> Self: + args = cls._parse_field_resolvers(data, world_cls.game) + options = OptionFilter.multiple_from_dict(data.get("options", ())) + children = [world_cls.rule_from_dict(c) for c in data.get("children", ())] + return cls( + args.pop("count"), + *children, + options=options, + filtered_resolution=data.get("filtered_resolution", False), + ) + + class Resolved(NestedRule.Resolved): + count: int + + @override + def _evaluate(self, state: CollectionState) -> bool: + count = self.count + for rule in self.children: + if rule(state): + if count == 1: + return True + count -= 1 + return False + + @override + def explain_json(self, state: CollectionState | None = None) -> list[JSONMessagePart]: + messages: list[JSONMessagePart] = [] + if state is None: + messages = [ + {"type": "text", "text": "At least "}, + {"type": "color", "color": "cyan", "text": str(self.count)}, + {"type": "text", "text": " of ("}, + ] + else: + satisfied_count = sum(1 if child(state) else 0 for child in self.children) + messages = [ + {"type": "text", "text": "At least "}, + {"type": "color", "color": "cyan", "text": f"{satisfied_count}/{self.count}"}, + {"type": "text", "text": " of ("}, + ] + for i, child in enumerate(self.children): + if i > 0: + messages.append({"type": "text", "text": ", "}) + messages.extend(child.explain_json(state)) + messages.append({"type": "text", "text": ")"}) + return messages + + @override + def explain_str(self, state: CollectionState | None = None) -> str: + clauses = ", ".join([c.explain_str(state) for c in self.children]) + if state is None: + return f"At least {self.count} of ({clauses})" + satisfied_count = sum(1 if child(state) else 0 for child in self.children) + return f"At least {satisfied_count}/{self.count} of ({clauses})" + + @override + def __str__(self) -> str: + clauses = ", ".join([str(c) for c in self.children]) + return f"At least {self.count} of ({clauses})" + + @dataclasses.dataclass(init=False) class And(NestedRule[TWorld], game="Archipelago"): """A rule that only returns true when all child rules evaluate as true""" @override def _instantiate(self, world: TWorld) -> Rule.Resolved: - children_to_process = [c.resolve(world) for c in self.children] + return And.from_resolved(world, [c.resolve(world) for c in self.children]) + + @classmethod + def from_resolved(cls, world: TWorld, children_to_process: list[Rule.Resolved]) -> Rule.Resolved: clauses: list[Rule.Resolved] = [] items: dict[str, int] = {} true_rule: Rule.Resolved | None = None @@ -518,7 +647,10 @@ class Or(NestedRule[TWorld], game="Archipelago"): @override def _instantiate(self, world: TWorld) -> Rule.Resolved: - children_to_process = [c.resolve(world) for c in self.children] + return Or.from_resolved(world, [c.resolve(world) for c in self.children]) + + @classmethod + def from_resolved(cls, world: TWorld, children_to_process: list[Rule.Resolved]) -> Rule.Resolved: clauses: list[Rule.Resolved] = [] items: dict[str, int] = {} diff --git a/test/general/test_rule_builder.py b/test/general/test_rule_builder.py index 682c043f8e..191ba3cba7 100644 --- a/test/general/test_rule_builder.py +++ b/test/general/test_rule_builder.py @@ -12,6 +12,7 @@ from rule_builder.field_resolvers import FieldResolver, FromOption, FromWorldAtt from rule_builder.options import Operator, OptionFilter from rule_builder.rules import ( And, + AtLeast, CanReachEntrance, CanReachLocation, CanReachRegion, @@ -250,6 +251,40 @@ class CachedRuleBuilderTestCase(RuleBuilderTestCase): Or(HasAnyCount({"A": 1, "B": 2}), HasAnyCount({"A": 2, "B": 2})), HasAnyCount.Resolved((("A", 1), ("B", 2)), player=1), ), + ( + AtLeast(0, Has("A")), + True_.Resolved(player=1), + ), + ( + AtLeast(3, True_(), Has("A"), Has("B"), Has("C")), + AtLeast.Resolved( + (Has.Resolved("A", player=1), Has.Resolved("B", player=1), Has.Resolved("C", player=1)), 2, player=1 + ), + ), + ( + AtLeast(2, False_(), Has("A"), Has("B"), Has("C")), + AtLeast.Resolved( + (Has.Resolved("A", player=1), Has.Resolved("B", player=1), Has.Resolved("C", player=1)), 2, player=1 + ), + ), + ( + AtLeast(2, True_(), True_(), Has("A")), + True_.Resolved(player=1), + ), + ( + AtLeast(3, Has("A"), Has("B")), + False_.Resolved(player=1), + ), + ( + # This test will fail when Or(Rule, Rule) will be optimized to Rule + AtLeast(1, Rule(), Rule()), + Or.Resolved((Rule.Resolved(player=1), Rule.Resolved(player=1)), player=1), + ), + ( + # This test will fail when And(Rule, Rule) will be optimized to Rule + AtLeast(2, Rule(), Rule()), + And.Resolved((Rule.Resolved(player=1), Rule.Resolved(player=1)), player=1), + ), ) ) class TestSimplify(RuleBuilderTestCase): @@ -631,6 +666,24 @@ class TestRules(RuleBuilderTestCase): self.state.remove(item) self.assertFalse(resolved_rule(self.state)) + def test_at_least(self) -> None: + # Has has to be relied on as True_ and False_ would be optimized out + rule = AtLeast(2, Has("Item 1"), Has("Item 1"), Has("Item 2"), Has("Item 3")) + resolved_rule = rule.resolve(self.world) + self.world.register_rule_dependencies(resolved_rule) + item1 = self.world.create_item("Item 1") + item2 = self.world.create_item("Item 2") + item3 = self.world.create_item("Item 3") + self.assertFalse(resolved_rule(self.state)) + self.state.collect(item1) + self.assertTrue(resolved_rule(self.state)) + self.state.collect(item2) + self.assertTrue(resolved_rule(self.state)) + self.state.remove(item1) + self.assertFalse(resolved_rule(self.state)) + self.state.collect(item3) + self.assertTrue(resolved_rule(self.state)) + def test_has_all(self) -> None: rule = HasAll("Item 1", "Item 2") resolved_rule = rule.resolve(self.world) @@ -806,8 +859,13 @@ class TestSerialization(RuleBuilderTestCase): OptionFilter(ChoiceOption, ChoiceOption.option_second, "ge"), ], ), + AtLeast( + FromWorldAttr("instance_data.at_least_requirement"), + Has("i15", count=2), + HasGroup("g2", count=3), + ), CanReachEntrance("e1"), - HasGroupUnique("g2", count=5), + HasGroupUnique("g3", count=5), ) rule_dict: ClassVar[dict[str, Any]] = { @@ -931,6 +989,29 @@ class TestSerialization(RuleBuilderTestCase): }, ], }, + { + "rule": "AtLeast", + "options": [], + "filtered_resolution": False, + "count": {"resolver": "FromWorldAttr", "name": "instance_data.at_least_requirement"}, + "children": [ + { + "rule": "Has", + "options": [], + "filtered_resolution": False, + "args": { + "item_name": "i15", + "count": 2, + }, + }, + { + "rule": "HasGroup", + "options": [], + "filtered_resolution": False, + "args": {"item_name_group": "g2", "count": 3}, + }, + ], + }, { "rule": "CanReachEntrance", "options": [], @@ -941,7 +1022,7 @@ class TestSerialization(RuleBuilderTestCase): "rule": "HasGroupUnique", "options": [], "filtered_resolution": False, - "args": {"item_name_group": "g2", "count": 5}, + "args": {"item_name_group": "g3", "count": 5}, }, ], } @@ -973,9 +1054,15 @@ class TestExplain(RuleBuilderTestCase): ), player=1, ), - HasAllCounts.Resolved((("Item 6", 1), ("Item 7", 5)), player=1), - HasAnyCount.Resolved((("Item 8", 2), ("Item 9", 3)), player=1), - HasFromList.Resolved(("Item 10", "Item 11", "Item 12"), count=2, player=1), + AtLeast.Resolved( + children=( + HasAllCounts.Resolved((("Item 6", 1), ("Item 7", 5)), player=1), + HasAnyCount.Resolved((("Item 8", 2), ("Item 9", 3)), player=1), + HasFromList.Resolved(("Item 10", "Item 11", "Item 12"), count=2, player=1), + ), + count=2, + player=1, + ), HasFromListUnique.Resolved(("Item 13", "Item 14"), player=1), HasGroup.Resolved("Group 1", ("Item 15", "Item 16", "Item 17"), player=1), HasGroupUnique.Resolved("Group 2", ("Item 18", "Item 19"), count=2, player=1), @@ -1040,6 +1127,9 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ")"}, {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, + {"type": "text", "text": "At least "}, + {"type": "color", "color": "cyan", "text": "0/2"}, + {"type": "text", "text": " of ("}, {"type": "text", "text": "Missing "}, {"type": "color", "color": "cyan", "text": "some"}, {"type": "text", "text": " of ("}, @@ -1050,7 +1140,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "salmon", "text": "Item 7"}, {"type": "text", "text": " x5"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Missing "}, {"type": "color", "color": "cyan", "text": "all"}, {"type": "text", "text": " of ("}, @@ -1061,7 +1151,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "salmon", "text": "Item 9"}, {"type": "text", "text": " x3"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "salmon", "text": "0/2"}, {"type": "text", "text": " items from ("}, @@ -1072,6 +1162,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ", "}, {"type": "color", "color": "salmon", "text": "Item 12"}, {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "salmon", "text": "0/1"}, @@ -1138,6 +1229,9 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ")"}, {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, + {"type": "text", "text": "At least "}, + {"type": "color", "color": "cyan", "text": "3/2"}, + {"type": "text", "text": " of ("}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "all"}, {"type": "text", "text": " of ("}, @@ -1148,7 +1242,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "green", "text": "Item 7"}, {"type": "text", "text": " x5"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "some"}, {"type": "text", "text": " of ("}, @@ -1159,7 +1253,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "green", "text": "Item 9"}, {"type": "text", "text": " x3"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "green", "text": "30/2"}, {"type": "text", "text": " items from ("}, @@ -1170,6 +1264,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ", "}, {"type": "color", "color": "green", "text": "Item 12"}, {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "green", "text": "2/1"}, @@ -1204,7 +1299,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "salmon", "text": "False"}, {"type": "text", "text": ")"}, ] - assert self.resolved_rule.explain_json(self.state) == expected + self.assertEqual(self.resolved_rule.explain_json(self.state), expected) def test_explain_json_without_state(self) -> None: expected: list[JSONMessagePart] = [ @@ -1232,6 +1327,9 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ")"}, {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, + {"type": "text", "text": "At least "}, + {"type": "color", "color": "cyan", "text": "2"}, + {"type": "text", "text": " of ("}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "all"}, {"type": "text", "text": " of ("}, @@ -1241,7 +1339,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "item_name", "flags": 1, "text": "Item 7", "player": 1}, {"type": "text", "text": " x5"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "any"}, {"type": "text", "text": " of ("}, @@ -1251,7 +1349,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "item_name", "flags": 1, "text": "Item 9", "player": 1}, {"type": "text", "text": " x3"}, {"type": "text", "text": ")"}, - {"type": "text", "text": " & "}, + {"type": "text", "text": ", "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "2"}, {"type": "text", "text": "x items from ("}, @@ -1261,6 +1359,7 @@ class TestExplain(RuleBuilderTestCase): {"type": "text", "text": ", "}, {"type": "item_name", "flags": 1, "text": "Item 12", "player": 1}, {"type": "text", "text": ")"}, + {"type": "text", "text": ")"}, {"type": "text", "text": " & "}, {"type": "text", "text": "Has "}, {"type": "color", "color": "cyan", "text": "1"}, @@ -1294,16 +1393,16 @@ class TestExplain(RuleBuilderTestCase): {"type": "color", "color": "salmon", "text": "False"}, {"type": "text", "text": ")"}, ] - assert self.resolved_rule.explain_json() == expected + self.assertEqual(self.resolved_rule.explain_json(), expected) def test_explain_str_with_state_no_items(self) -> None: expected = ( "((Missing 4x Item 1", "| Missing some of (Missing: Item 2, Item 3)", "| Missing all of (Missing: Item 4, Item 5))", - "& Missing some of (Missing: Item 6 x1, Item 7 x5)", - "& Missing all of (Missing: Item 8 x2, Item 9 x3)", - "& Has 0/2 items from (Missing: Item 10, Item 11, Item 12)", + "& At least 0/2 of (Missing some of (Missing: Item 6 x1, Item 7 x5),", + "Missing all of (Missing: Item 8 x2, Item 9 x3),", + "Has 0/2 items from (Missing: Item 10, Item 11, Item 12))", "& Has 0/1 unique items from (Missing: Item 13, Item 14)", "& Has 0/1 items from Group 1", "& Has 0/2 unique items from Group 2", @@ -1313,7 +1412,7 @@ class TestExplain(RuleBuilderTestCase): "& True", "& False)", ) - assert self.resolved_rule.explain_str(self.state) == " ".join(expected) + self.assertEqual(self.resolved_rule.explain_str(self.state), " ".join(expected)) def test_explain_str_with_state_all_items(self) -> None: self._collect_all() @@ -1322,9 +1421,9 @@ class TestExplain(RuleBuilderTestCase): "((Has 4x Item 1", "| Has all of (Found: Item 2, Item 3)", "| Has some of (Found: Item 4, Item 5))", - "& Has all of (Found: Item 6 x1, Item 7 x5)", - "& Has some of (Found: Item 8 x2, Item 9 x3)", - "& Has 30/2 items from (Found: Item 10, Item 11, Item 12)", + "& At least 3/2 of (Has all of (Found: Item 6 x1, Item 7 x5),", + "Has some of (Found: Item 8 x2, Item 9 x3),", + "Has 30/2 items from (Found: Item 10, Item 11, Item 12))", "& Has 2/1 unique items from (Found: Item 13, Item 14)", "& Has 30/1 items from Group 1", "& Has 2/2 unique items from Group 2", @@ -1334,16 +1433,16 @@ class TestExplain(RuleBuilderTestCase): "& True", "& False)", ) - assert self.resolved_rule.explain_str(self.state) == " ".join(expected) + self.assertEqual(self.resolved_rule.explain_str(self.state), " ".join(expected)) def test_explain_str_without_state(self) -> None: expected = ( "((Has 4x Item 1", "| Has all of (Item 2, Item 3)", "| Has any of (Item 4, Item 5))", - "& Has all of (Item 6 x1, Item 7 x5)", - "& Has any of (Item 8 x2, Item 9 x3)", - "& Has 2x items from (Item 10, Item 11, Item 12)", + "& At least 2 of (Has all of (Item 6 x1, Item 7 x5),", + "Has any of (Item 8 x2, Item 9 x3),", + "Has 2x items from (Item 10, Item 11, Item 12))", "& Has a unique item from (Item 13, Item 14)", "& Has an item from Group 1", "& Has 2x unique items from Group 2", @@ -1353,16 +1452,16 @@ class TestExplain(RuleBuilderTestCase): "& True", "& False)", ) - assert self.resolved_rule.explain_str() == " ".join(expected) + self.assertEqual(self.resolved_rule.explain_str(), " ".join(expected)) def test_str(self) -> None: expected = ( "((Has 4x Item 1", "| Has all of (Item 2, Item 3)", "| Has any of (Item 4, Item 5))", - "& Has all of (Item 6 x1, Item 7 x5)", - "& Has any of (Item 8 x2, Item 9 x3)", - "& Has 2x items from (Item 10, Item 11, Item 12)", + "& At least 2 of (Has all of (Item 6 x1, Item 7 x5),", + "Has any of (Item 8 x2, Item 9 x3),", + "Has 2x items from (Item 10, Item 11, Item 12))", "& Has a unique item from (Item 13, Item 14)", "& Has an item from Group 1", "& Has 2x unique items from Group 2", @@ -1372,7 +1471,7 @@ class TestExplain(RuleBuilderTestCase): "& True", "& False)", ) - assert str(self.resolved_rule) == " ".join(expected) + self.assertEqual(str(self.resolved_rule), " ".join(expected)) @classvar_matrix(