From 5e8a2199cb34a0e026c9fda83adcd6b9acca7054 Mon Sep 17 00:00:00 2001 From: spinerak Date: Fri, 14 Jun 2024 19:15:06 +0200 Subject: [PATCH] Revert "Use numpy and pmf function to speed up gen" This reverts commit 9290191cb323ae92321d6c2cfcfe8c27370f439b. --- worlds/yachtdice/Rules.py | 87 ++++++++++++++------------------------- 1 file changed, 32 insertions(+), 55 deletions(-) diff --git a/worlds/yachtdice/Rules.py b/worlds/yachtdice/Rules.py index 5836e81f91..69f6a4bd67 100644 --- a/worlds/yachtdice/Rules.py +++ b/worlds/yachtdice/Rules.py @@ -7,7 +7,6 @@ from BaseClasses import MultiWorld from worlds.generic.Rules import set_rule from .YachtWeights import yacht_weights -import numpy as np # This module adds logic to the apworld. @@ -70,6 +69,8 @@ class Category: return mean_score * self.quantity + + class ListState: def __init__(self, state: List[str]): self.state = state @@ -98,25 +99,19 @@ def extract_progression(state, player, options): ) number_of_fixed_mults = state.count("Fixed Score Multiplier", player) number_of_step_mults = state.count("Step Score Multiplier", player) - + categories = [ Category(category_value, state.count(category_name, player)) for category_name, category_value in category_mappings.items() if state.count(category_name, player) # want all categories that have count >= 1 - ] - + ] + extra_points_in_logic = state.count("1 Point", player) extra_points_in_logic += state.count("10 Points", player) * 10 extra_points_in_logic += state.count("100 Points", player) * 100 - return ( - categories, - number_of_dice, - number_of_rerolls, - number_of_fixed_mults * 0.1, - number_of_step_mults * 0.01, - extra_points_in_logic, - ) + return categories, number_of_dice, number_of_rerolls, number_of_fixed_mults * 0.1, number_of_step_mults * 0.01, extra_points_in_logic, + # We will store the results of this function as it is called often for the same parameters. @@ -145,44 +140,18 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu # sort categories because for the step multiplier, you will want low-scoring categories first categories.sort(key=lambda category: category.mean_score(num_dice, num_rolls)) - # we have two ways to store a distribution (example, 0 with probability 0.4, 10 with probability 0.6): - # dict: {0: 0.4, 10: 0.6} - # pmf (probability mass function): [0.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6] (numpy array) - # adding two distributions works fast with pmf's and numpy's convolve method. - # maximizing two distributions (with multipliers) seems to work fastest with dictionaries. - - def dict_to_pmf(dist): - """ - Convert dict-distribution to pmf-distribution - """ - max_value = max(dist) + 1 - return np.array([dist.get(i, 0) for i in range(max_value)]) - - def pmf_to_dict(pmf): - """ - Convert pmf-distribution to dict-distribution - """ - sum_values = np.arange(0, len(pmf) + 1) - sum_dist = {v: p for v, p in zip(sum_values, pmf)} - return sum_dist - - def add_distributions(pmf1, dic2): - """ - function to add two discrete distributions. The first in pmf form, the second in dict form, returns pmf. - """ - pmf2 = dict_to_pmf(dic2) - - # Sum the two distributions using convolution - sum_pmf = np.convolve(pmf1, pmf2) - - return sum_pmf + # function to add two discrete distribution. + # defaultdict is a dict where you don't need to check if an id is present, you can just use += (lot faster) + def add_distributions(dist1, dist2): + combined_dist = defaultdict(float) + for val1, prob1 in dist1.items(): + for val2, prob2 in dist2.items(): + combined_dist[val1 + val2] += prob1 * prob2 + return dict(combined_dist) + # function to take the maximum of "times" i.i.d. dist1. + # (I have tried using defaultdict here too but this made it slower.) def max_dist(dist1, mults): - """ - function to take the maximum of "times" i.i.d. dist1. - dist1 is a dict-distribution - (I have tried using defaultdict here too but this made it slower.) - """ new_dist = {0: 1} for mult in mults: c = new_dist.copy() @@ -202,10 +171,16 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu # Returns percentile value of a distribution. def percentile_distribution(dist, percentile): - cumdist = np.cumsum(dist) + sorted_values = sorted(dist.keys()) + cumulative_prob = 0 + + for val in sorted_values: + cumulative_prob += dist[val] + if cumulative_prob >= percentile: + return val # Return the last value if percentile is higher than all probabilities - return np.argmax(cumdist > percentile) + return sorted_values[-1] # parameters for logic. # perc_return is, per difficulty, the percentages of total score it returns (it averages out the values) @@ -213,8 +188,8 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu perc_return = [[0], [0.1, 0.5], [0.3, 0.7], [0.55, 0.85], [0.85, 0.95]][diff] diff_divide = [0, 9, 7, 3, 2][diff] - # calculate total distribution, start in pmf-form - total_dist = [1] + # calculate total distribution + total_dist = {0: 1} for j, category in enumerate(categories): if num_dice == 0 or num_rolls == 0: dist = {0: 100000} @@ -233,14 +208,14 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu total_dist = add_distributions(total_dist, dist) - # note, total_dist is in pmf-form - outcome = sum([percentile_distribution(total_dist, perc) for perc in perc_return]) / len(perc_return) # save result into the cache, then return it + outcome = sum([percentile_distribution(total_dist, perc) for perc in perc_return]) / len(perc_return) yachtdice_cache[tup] = max(5, math.floor(outcome)) # at least 5. - + return yachtdice_cache[tup] + def dice_simulation(state, player, options): """ Returns the feasible score that one can reach with the current state, options and difficulty. @@ -268,6 +243,7 @@ def dice_simulation(state, player, options): return state.prog_items[player]["maximum_achievable_score"] + def set_yacht_rules(world: MultiWorld, player: int, options): """ Sets rules on entrances and advancements that are always applied @@ -286,3 +262,4 @@ def set_yacht_completion_rules(world: MultiWorld, player: int): Sets rules on completion condition """ world.completion_condition[player] = lambda state: state.has("Victory", player) + \ No newline at end of file