# based largely on https://github.com/ArchipelagoMW/Archipelago/blob/main/worlds/oot/RuleParser.py
# the way the regions are set up in json files follows the way they set theirs up, so we parse them similarly
import ast
from collections import defaultdict
from inspect import signature, Parameter
import logging
import re

from .data.ItemList import item_table, item_groups
from BaseClasses import CollectionState
from .Locations import PMLocation
from .Utils import load_json_data
from worlds.generic.Rules import set_rule

# Maps the identifier-safe spelling of an item name (spaces -> underscores,
# quotes/parentheses/brackets stripped) back to the real item name.
escaped_items = {}
for item in item_table:
    escaped_items[re.sub(r'[\'()[\]]', '', item.replace(' ', '_'))] = item

event_name = re.compile(r'\w+')
kwarg_defaults = {}
allowed_globals = {}

# rule name -> (list of compiled argument-name patterns, replacement rule text)
rule_aliases = {}
# escaped item names that are not shadowed by a rule alias
nonaliases = set()


def load_aliases():
    """Populate ``rule_aliases`` and ``nonaliases`` from LogicHelpers.json.

    Every key of the JSON object is either a bare rule name or
    ``name(arg1, arg2, ...)``; the value is the rule text the alias expands
    to. For parameterised aliases each argument name is compiled to a
    word-boundary regex so it can be substituted textually on expansion.
    """
    j = load_json_data("LogicHelpers.json")
    for s, repl in j.items():
        if '(' in s:
            rule, args = s[:-1].split('(', 1)
            args = [re.compile(r'\b%s\b' % a.strip()) for a in args.split(',')]
        else:
            rule = s
            args = ()
        rule_aliases[rule] = (args, repl)
    # Update the module-level set in place. A plain assignment here would only
    # bind a function-local name and leave the global ``nonaliases`` empty.
    nonaliases.update(escaped_items.keys() - rule_aliases.keys())


def isliteral(expr):
    """Return True if *expr* is a numeric, string, bytes or True/False/None literal node."""
    # NOTE(review): ast.Num/Str/Bytes/NameConstant are deprecated aliases of
    # ast.Constant (see the visit_Constant comment below); migrating them is
    # left for a dedicated change.
    return isinstance(expr, (ast.Num, ast.Str, ast.Bytes, ast.NameConstant))


class Rule_AST_Transformer(ast.NodeTransformer):
    """Rewrite a JSON logic rule expression into a ``lambda state, *, ...`` access rule.

    Names, strings, tuples, calls, comparisons and boolean operators found in
    a rule string are rewritten into calls on a ``state`` object, constant
    sub-expressions are folded, and the resulting expression is compiled to a
    callable that is cached per distinct AST.
    """

    # Canonical rule_cache keys for the constant-True / constant-False rules.
    _ALWAYS_KEY = 'NameConstant(True)'
    _NEVER_KEY = 'NameConstant(False)'

    def __init__(self, world, player):
        """Create a transformer for one player.

        :param world: object stored as ``self.multiworld``; this class reads its
            ``options``, ``multiworld`` and ``make_event_item`` attributes and
            its instance ``__dict__``.
        :param player: player id inserted into every generated ``state`` call.
        """
        self.multiworld = world
        self.player = player
        self.events = set()
        # map Region -> rule ast string -> item name
        self.replaced_rules = defaultdict(dict)
        # delayed rules need to keep: region name, ast node, event name
        self.delayed_rules = []
        # lazy load aliases
        if not rule_aliases:
            load_aliases()
        # final rule cache
        self.rule_cache = {}
        self.kwarg_defaults = kwarg_defaults.copy()  # otherwise this gets contaminated between players
        self.kwarg_defaults['player'] = self.player

    def _has_call(self, attr, args):
        """Build the AST for ``state.<attr>(*args)``."""
        return ast.Call(
            func=ast.Attribute(
                value=ast.Name(id='state', ctx=ast.Load()),
                attr=attr,
                ctx=ast.Load()),
            args=args,
            keywords=[])

    def visit_Name(self, node):
        """Resolve a bare identifier.

        Checked in order: handler method on this object, rule alias, item
        name, option value, world attribute, CollectionState member,
        lambda keyword / allowed global, and finally an event name.
        """
        # NOTE(review): dir(self) also contains non-callable instance
        # attributes (e.g. 'player', 'events'); a rule using such a name would
        # fail on the call below -- confirm no logic file does so.
        if node.id in dir(self):
            return getattr(self, node.id)(node)
        elif node.id in rule_aliases:
            args, repl = rule_aliases[node.id]
            if args:
                raise Exception('Parse Error: expected %d args for %s, not 0' % (len(args), node.id),
                                self.current_spot.name, ast.dump(node, False))
            return self.visit(ast.parse(repl, mode='eval').body)
        elif node.id in escaped_items:
            return self._has_call(
                'has', [ast.Str(escaped_items[node.id]), ast.Constant(self.player)])
        elif node.id in self.multiworld.options.__dict__:
            # Settings are constant
            return ast.parse('%r' % self.multiworld.options.__dict__[node.id].value, mode='eval').body
        elif node.id in self.multiworld.__dict__:
            return ast.parse('%r' % self.multiworld.__dict__[node.id], mode='eval').body
        elif node.id in CollectionState.__dict__:
            return self.make_call(node, node.id, [], [])
        elif node.id in self.kwarg_defaults or node.id in allowed_globals:
            return node
        elif event_name.match(node.id):
            # Unknown identifier: treat it as an event, with underscores as spaces.
            self.events.add(node.id.replace('_', ' '))
            return self._has_call(
                'has', [ast.Str(node.id.replace('_', ' ')), ast.Constant(self.player)])
        else:
            raise Exception('Parse Error: invalid node name %s' % node.id,
                            self.current_spot.name, ast.dump(node, False))

    def visit_Str(self, node):
        """Turn a string literal into ``state.has(<string>, player)``."""
        return self._has_call('has', [ast.Str(node.s), ast.Constant(self.player)])

    # python 3.8 compatibility: ast walking now uses visit_Constant for Constant subclasses
    # this includes Num, Str, NameConstant, Bytes, and Ellipsis. We only handle Str.
    def visit_Constant(self, node):
        """Route string constants to ``visit_Str``; leave other constants untouched."""
        if isinstance(node, ast.Str):
            return self.visit_Str(node)
        return node

    def visit_Tuple(self, node):
        """Turn ``(item, count)`` into ``state.has(item, player, count)``.

        ``state.has_group`` is used instead when the item names an item group.
        *count* may be a number literal or the name of an option, whose value
        is substituted. Names that are neither items nor groups are recorded
        as events.

        :raises Exception: if the tuple does not have exactly two elements or
            the elements have the wrong node type.
        """
        if len(node.elts) != 2:
            raise Exception('Parse Error: Tuple must have 2 values',
                            self.current_spot.name, ast.dump(node, False))

        item, count = node.elts

        if not isinstance(item, (ast.Name, ast.Str)):
            raise Exception('Parse Error: first value must be an item. Got %s' % item.__class__.__name__,
                            self.current_spot.name, ast.dump(node, False))
        iname = item.id if isinstance(item, ast.Name) else item.s

        if not isinstance(count, (ast.Name, ast.Num)):
            # Report the type of the offending *count* node, not of the item.
            raise Exception('Parse Error: second value must be a number. Got %s' % count.__class__.__name__,
                            self.current_spot.name, ast.dump(node, False))

        if isinstance(count, ast.Name):
            # Must be a settings constant
            count = ast.parse('%r' % self.multiworld.options.__dict__[count.id].value, mode='eval').body

        if iname in escaped_items:
            iname = escaped_items[iname]

        if iname not in item_table and iname not in item_groups:
            self.events.add(iname)

        attr = 'has_group' if iname in item_groups else 'has'
        return self._has_call(attr, [ast.Str(iname), ast.Constant(self.player), count])

    def visit_Call(self, node):
        """Resolve a function-call expression.

        Handler methods on this object (``at``, ``here``) are dispatched to
        directly, rule aliases are expanded by textual argument substitution
        and re-parsed, and everything else becomes a call on ``state`` via
        ``make_call``.
        """
        if not isinstance(node.func, ast.Name):
            return node

        if node.func.id in dir(self):
            return getattr(self, node.func.id)(node)
        elif node.func.id in rule_aliases:
            args, repl = rule_aliases[node.func.id]
            if len(args) != len(node.args):
                raise Exception(
                    'Parse Error: expected %d args for %s, not %d' % (len(args), node.func.id, len(node.args)),
                    self.current_spot.name, ast.dump(node, False))

            # straightforward string manip
            for arg_re, arg_val in zip(args, node.args):
                if isinstance(arg_val, ast.Name):
                    val = arg_val.id
                elif isinstance(arg_val, ast.Constant):
                    val = repr(arg_val.value)
                elif isinstance(arg_val, ast.Str):
                    val = repr(arg_val.s)
                else:
                    raise Exception('Parse Error: invalid argument %s' % ast.dump(arg_val, False),
                                    self.current_spot.name, ast.dump(node, False))
                repl = arg_re.sub(val, repl)
            return self.visit(ast.parse(repl, mode='eval').body)

        new_args = []
        for child in node.args:
            if isinstance(child, ast.Name):
                if child.id in self.multiworld.__dict__:
                    # World attributes are baked in as constants.
                    child = ast.Constant(getattr(self.multiworld, child.id))
                elif child.id in rule_aliases:
                    child = self.visit(child)
                elif child.id in escaped_items:
                    child = ast.Str(escaped_items[child.id])
                else:
                    child = ast.Str(child.id.replace('_', ' '))
            elif not isinstance(child, ast.Str):
                child = self.visit(child)
            new_args.append(child)

        return self.make_call(node, node.func.id, new_args, node.keywords)

    def visit_Subscript(self, node):
        """Rewrite ``name[key]`` into ``state.multiworld.worlds[player].name['key']``.

        Underscores in *key* are replaced by spaces. Subscripts whose value is
        not a plain name are returned unchanged.
        """
        if isinstance(node.value, ast.Name):
            s = node.slice if isinstance(node.slice, ast.Name) else node.slice.value
            return ast.Subscript(
                value=ast.Attribute(
                    value=ast.Subscript(
                        value=ast.Attribute(
                            value=ast.Attribute(
                                value=ast.Name(id='state', ctx=ast.Load()),
                                attr='multiworld',
                                ctx=ast.Load()),
                            attr='worlds',
                            ctx=ast.Load()),
                        slice=ast.Index(value=ast.Constant(self.player)),
                        ctx=ast.Load()),
                    attr=node.value.id,
                    ctx=ast.Load()),
                slice=ast.Index(value=ast.Str(s.id.replace('_', ' '))),
                ctx=node.ctx)
        else:
            return node

    def visit_Compare(self, node):
        """Fold comparisons whose operands are all literals after substitution.

        :raises Exception: if evaluating the folded comparison raises TypeError.
        """
        def escape_or_string(n):
            if isinstance(n, ast.Name) and n.id in escaped_items:
                return ast.Str(escaped_items[n.id])
            elif not isinstance(n, ast.Str):
                return self.visit(n)
            return n

        # Fast check for json can_use
        if (len(node.ops) == 1 and isinstance(node.ops[0], ast.Eq)
                and isinstance(node.left, ast.Name) and isinstance(node.comparators[0], ast.Name)
                and node.left.id not in self.multiworld.__dict__
                and node.comparators[0].id not in self.multiworld.__dict__):
            return ast.NameConstant(node.left.id == node.comparators[0].id)

        node.left = escape_or_string(node.left)
        node.comparators = list(map(escape_or_string, node.comparators))
        node.ops = list(map(self.visit, node.ops))

        # if all the children are literals now, we can evaluate
        if isliteral(node.left) and all(map(isliteral, node.comparators)):
            # either we turn the ops into operator functions to apply (lots of work),
            # or we compile, eval, and reparse the result
            try:
                res = eval(compile(ast.fix_missing_locations(ast.Expression(node)), '', 'eval'))
            except TypeError as e:
                raise Exception('Parse Error: %s' % e, self.current_spot.name, ast.dump(node, False))
            return self.visit(ast.parse('%r' % res, mode='eval').body)
        return node

    def visit_UnaryOp(self, node):
        """Visit the operand, then fold the operation if the operand became a literal."""
        # visit the children first
        self.generic_visit(node)
        # if all the children are literals now, we can evaluate
        if isliteral(node.operand):
            # Children may have been replaced by freshly built nodes that carry
            # no line/column info (e.g. the visit_Compare fast path), which
            # compile() rejects -- fill the locations in first.
            res = eval(compile(ast.fix_missing_locations(ast.Expression(node)), '', 'eval'))
            return ast.parse('%r' % res, mode='eval').body
        return node

    def visit_BinOp(self, node):
        """Visit both operands, then fold the operation if both became literals."""
        # visit the children first
        self.generic_visit(node)
        # if all the children are literals now, we can evaluate
        if isliteral(node.left) and isliteral(node.right):
            # See visit_UnaryOp: replaced children may lack location info.
            res = eval(compile(ast.fix_missing_locations(ast.Expression(node)), '', 'eval'))
            return ast.parse('%r' % res, mode='eval').body
        return node

    def visit_BoolOp(self, node):
        """Simplify an ``and``/``or`` chain.

        Plain item checks are collected into one ``state.has_all`` (for
        ``and``) or ``state.has_any`` (for ``or``) call, constant operands are
        dropped or short-circuit the whole expression, and nested boolean
        operations of the same kind are flattened.
        """
        # Everything else must be visited, then can be removed/reduced to.
        early_return = isinstance(node.op, ast.Or)
        groupable = 'has_any' if early_return else 'has_all'
        items = set()
        new_values = []
        # if any elt is True(And)/False(Or), we can omit it
        # if any is False(And)/True(Or), the whole node can be replaced with it
        for elt in list(node.values):
            if isinstance(elt, ast.Str):
                items.add(elt.s)
            elif isinstance(elt, ast.Name) and elt.id in nonaliases:
                items.add(escaped_items[elt.id])
            else:
                # It's possible this returns a single item check,
                # but it's already wrapped in a Call.
                elt = self.visit(elt)
                if isinstance(elt, ast.NameConstant):
                    if elt.value == early_return:
                        return elt
                    # else omit it
                elif (isinstance(elt, ast.Call) and isinstance(elt.func, ast.Attribute)
                        and elt.func.attr in ('has', groupable) and len(elt.args) == 1):
                    args = elt.args[0]
                    if isinstance(args, ast.Str):
                        items.add(args.s)
                    else:
                        items.update(it.s for it in args.elts)
                elif isinstance(elt, ast.BoolOp) and node.op.__class__ == elt.op.__class__:
                    new_values.extend(elt.values)
                else:
                    new_values.append(elt)

        # package up the remaining items and values
        if not items and not new_values:
            # all values were True(And)/False(Or)
            return ast.NameConstant(not early_return)

        if items:
            # Sorted so the generated AST (and thus the rule_cache key) does
            # not depend on set iteration order.
            grouped = self._has_call(
                groupable,
                [ast.Tuple(elts=[ast.Str(i) for i in sorted(items)], ctx=ast.Load()),
                 ast.Constant(self.player)])
            node.values = [grouped] + new_values
        else:
            node.values = new_values
        if len(node.values) == 1:
            return node.values[0]
        return node

    # Generates an ast.Call invoking the given State function 'name',
    # providing given args and keywords, and adding in additional
    # keyword arguments taken from self.kwarg_defaults.
    def make_call(self, node, name, args, keywords):
        """Build ``state.<name>(*args, **keywords, **kwarg_defaults)``.

        :param node: the node being replaced; used only for error reporting.
        :param name: name of a ``CollectionState`` attribute to call.
        :param args: positional argument nodes.
        :param keywords: keyword nodes; the list itself is not modified.
        :raises Exception: if ``CollectionState`` has no attribute *name*.
        """
        if not hasattr(CollectionState, name):
            raise Exception('Parse Error: No such function State.%s' % name,
                            self.current_spot.name, ast.dump(node, False))

        # Build a new list instead of appending to the caller's one.
        all_keywords = list(keywords)
        all_keywords.extend(
            ast.keyword(arg=f'{k}', value=ast.Constant(v))
            for k, v in self.kwarg_defaults.items())

        return ast.Call(
            func=ast.Attribute(
                value=ast.Name(id='state', ctx=ast.Load()),
                attr=name,
                ctx=ast.Load()),
            args=args,
            keywords=all_keywords)

    def replace_subrule(self, target, node):
        """Replace *node* with a ``state.has`` check on a generated subrule event.

        The rule itself is queued in ``self.delayed_rules`` and is turned into
        an event location in region *target* by ``create_delayed_rules``.
        Identical rules for the same target share one event.
        """
        rule = ast.dump(node, False)
        if rule in self.replaced_rules[target]:
            return self.replaced_rules[target][rule]

        subrule_name = target + ' Subrule %d' % (1 + len(self.replaced_rules[target]))
        # Save the info to be made into a rule later
        self.delayed_rules.append((target, node, subrule_name))
        # Replace the call with a reference to that item
        item_rule = self._has_call('has', [ast.Str(subrule_name), ast.Constant(self.player)])
        # Cache the subrule for any others in this region
        # (and reserve the item name in the process)
        self.replaced_rules[target][rule] = item_rule
        return item_rule

    # Requires the target regions have been defined in the world.
    def create_delayed_rules(self):
        """Create the event locations queued by ``replace_subrule``.

        Events whose rule folds to constant False are dropped; all others are
        appended to their region and given an event item.
        """
        for region_name, node, subrule_name in self.delayed_rules:
            region = self.multiworld.multiworld.get_region(region_name, self.player)
            event = PMLocation(self.player, subrule_name, type='Event', parent=region, internal=True)
            event.show_in_spoiler = False

            self.current_spot = event
            # This could, in theory, create further subrules.
            access_rule = self.make_access_rule(self.visit(node))
            if access_rule is self.rule_cache.get(self._NEVER_KEY):
                event.access_rule = None
                event.never = True
                logging.getLogger('').debug('Dropping unreachable delayed event: %s', event.name)
            else:
                if access_rule is self.rule_cache.get(self._ALWAYS_KEY):
                    event.always = True
                set_rule(event, access_rule)
                region.locations.append(event)

                self.multiworld.make_event_item(subrule_name, event)
        # Safeguard in case this is called multiple times per world
        self.delayed_rules.clear()

    @staticmethod
    def _rule_cache_key(body):
        """Return the rule_cache key for the expression node *body*.

        Boolean constants get the fixed keys 'NameConstant(True)' and
        'NameConstant(False)'. Where booleans are ast.Constant nodes,
        ast.dump() spells them 'Constant(...)', which would never match the
        keys that parse_spot_rule and create_delayed_rules look up.
        """
        if isinstance(body, ast.Constant) and isinstance(body.value, bool):
            return 'NameConstant(%r)' % body.value
        return ast.dump(body, False)

    def make_access_rule(self, body):
        """Compile *body* into ``lambda state, *, <kwarg_defaults>: body``.

        Compiled rules are cached, so equal expressions share one callable.

        :raises Exception: if compiling the expression raises TypeError.
        """
        rule_str = self._rule_cache_key(body)
        if rule_str not in self.rule_cache:
            # requires consistent iteration on dicts
            kwargs = [ast.arg(arg=k) for k in self.kwarg_defaults.keys()]
            kwd = [ast.Constant(v) for v in self.kwarg_defaults.values()]
            try:
                self.rule_cache[rule_str] = eval(compile(
                    ast.fix_missing_locations(
                        ast.Expression(ast.Lambda(
                            args=ast.arguments(
                                posonlyargs=[],
                                args=[ast.arg(arg='state')],
                                defaults=[],
                                kwonlyargs=kwargs,
                                kw_defaults=kwd),
                            body=body))),
                    '', 'eval'),
                    # globals/locals. if undefined, everything in the namespace *now* would be allowed
                    allowed_globals)
            except TypeError as e:
                raise Exception('Parse Error: %s' % e, self.current_spot.name, ast.dump(body, False))
        return self.rule_cache[rule_str]

    ## Handlers for specific internal functions used in the json logic.

    # at(region_name, rule)
    # Creates an internal event at the remote region and depends on it.
    def at(self, node):
        """Handle ``at('Region Name', rule)``.

        :raises Exception: if fewer than two arguments are given or the first
            is not a string literal.
        """
        # Cache this under the target (region) name
        if len(node.args) < 2 or not isinstance(node.args[0], ast.Str):
            raise Exception('Parse Error: invalid at() arguments',
                            self.current_spot.name, ast.dump(node, False))
        return self.replace_subrule(node.args[0].s, node.args[1])

    # here(rule)
    # Creates an internal event in the same region and depends on it.
    def here(self, node):
        """Handle ``here(rule)`` using the current spot's parent region.

        :raises Exception: if no argument is given.
        """
        if not node.args:
            raise Exception('Parse Error: missing here() argument',
                            self.current_spot.name, ast.dump(node, False))
        return self.replace_subrule(
            self.current_spot.parent_region.name,
            node.args[0])

    # Parse entry point
    # If spot is None, here() rules won't work.
    def parse_rule(self, rule_string, spot=None):
        """Parse *rule_string* and return the compiled access rule callable."""
        self.current_spot = spot
        return self.make_access_rule(self.visit(ast.parse(rule_string, mode='eval').body))

    def parse_spot_rule(self, spot):
        """Parse ``spot.rule_string`` (ignoring any ``#`` comment) and set it on *spot*.

        Also marks the spot as ``never`` or ``always`` when the rule folds to
        constant False or True respectively.
        """
        rule = spot.rule_string.split('#', 1)[0].strip()

        access_rule = self.parse_rule(rule, spot)
        set_rule(spot, access_rule)
        if access_rule is self.rule_cache.get(self._NEVER_KEY):
            spot.never = True
        elif access_rule is self.rule_cache.get(self._ALWAYS_KEY):
            spot.always = True