mirror of
https://github.com/ArchipelagoMW/Archipelago.git
synced 2026-03-07 15:13:52 -08:00
Use numpy and pmf function to speed up gen
Numpy has a built-in way to sum probability mass functions (pmf). This shaves of 60% of the generation time :D
This commit is contained in:
@@ -7,6 +7,7 @@ from BaseClasses import MultiWorld
|
||||
from worlds.generic.Rules import set_rule
|
||||
|
||||
from .YachtWeights import yacht_weights
|
||||
import numpy as np
|
||||
|
||||
|
||||
# This module adds logic to the apworld.
|
||||
@@ -69,8 +70,6 @@ class Category:
|
||||
return mean_score * self.quantity
|
||||
|
||||
|
||||
|
||||
|
||||
class ListState:
|
||||
def __init__(self, state: List[str]):
|
||||
self.state = state
|
||||
@@ -99,19 +98,25 @@ def extract_progression(state, player, options):
|
||||
)
|
||||
number_of_fixed_mults = state.count("Fixed Score Multiplier", player)
|
||||
number_of_step_mults = state.count("Step Score Multiplier", player)
|
||||
|
||||
|
||||
categories = [
|
||||
Category(category_value, state.count(category_name, player))
|
||||
for category_name, category_value in category_mappings.items()
|
||||
if state.count(category_name, player) # want all categories that have count >= 1
|
||||
]
|
||||
|
||||
]
|
||||
|
||||
extra_points_in_logic = state.count("1 Point", player)
|
||||
extra_points_in_logic += state.count("10 Points", player) * 10
|
||||
extra_points_in_logic += state.count("100 Points", player) * 100
|
||||
|
||||
return categories, number_of_dice, number_of_rerolls, number_of_fixed_mults * 0.1, number_of_step_mults * 0.01, extra_points_in_logic,
|
||||
|
||||
return (
|
||||
categories,
|
||||
number_of_dice,
|
||||
number_of_rerolls,
|
||||
number_of_fixed_mults * 0.1,
|
||||
number_of_step_mults * 0.01,
|
||||
extra_points_in_logic,
|
||||
)
|
||||
|
||||
|
||||
# We will store the results of this function as it is called often for the same parameters.
|
||||
@@ -140,18 +145,44 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu
|
||||
# sort categories because for the step multiplier, you will want low-scoring categories first
|
||||
categories.sort(key=lambda category: category.mean_score(num_dice, num_rolls))
|
||||
|
||||
# function to add two discrete distribution.
|
||||
# defaultdict is a dict where you don't need to check if an id is present, you can just use += (lot faster)
|
||||
def add_distributions(dist1, dist2):
|
||||
combined_dist = defaultdict(float)
|
||||
for val1, prob1 in dist1.items():
|
||||
for val2, prob2 in dist2.items():
|
||||
combined_dist[val1 + val2] += prob1 * prob2
|
||||
return dict(combined_dist)
|
||||
# we have two ways to store a distribution (example, 0 with probability 0.4, 10 with probability 0.6):
|
||||
# dict: {0: 0.4, 10: 0.6}
|
||||
# pmf (probability mass function): [0.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6] (numpy array)
|
||||
# adding two distributions works fast with pmf's and numpy's convolve method.
|
||||
# maximizing two distributions (with multipliers) seems to work fastest with dictionaries.
|
||||
|
||||
def dict_to_pmf(dist):
|
||||
"""
|
||||
Convert dict-distribution to pmf-distribution
|
||||
"""
|
||||
max_value = max(dist) + 1
|
||||
return np.array([dist.get(i, 0) for i in range(max_value)])
|
||||
|
||||
def pmf_to_dict(pmf):
|
||||
"""
|
||||
Convert pmf-distribution to dict-distribution
|
||||
"""
|
||||
sum_values = np.arange(0, len(pmf) + 1)
|
||||
sum_dist = {v: p for v, p in zip(sum_values, pmf)}
|
||||
return sum_dist
|
||||
|
||||
def add_distributions(pmf1, dic2):
|
||||
"""
|
||||
function to add two discrete distributions. The first in pmf form, the second in dict form, returns pmf.
|
||||
"""
|
||||
pmf2 = dict_to_pmf(dic2)
|
||||
|
||||
# Sum the two distributions using convolution
|
||||
sum_pmf = np.convolve(pmf1, pmf2)
|
||||
|
||||
return sum_pmf
|
||||
|
||||
# function to take the maximum of "times" i.i.d. dist1.
|
||||
# (I have tried using defaultdict here too but this made it slower.)
|
||||
def max_dist(dist1, mults):
|
||||
"""
|
||||
function to take the maximum of "times" i.i.d. dist1.
|
||||
dist1 is a dict-distribution
|
||||
(I have tried using defaultdict here too but this made it slower.)
|
||||
"""
|
||||
new_dist = {0: 1}
|
||||
for mult in mults:
|
||||
c = new_dist.copy()
|
||||
@@ -171,16 +202,10 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu
|
||||
|
||||
# Returns percentile value of a distribution.
|
||||
def percentile_distribution(dist, percentile):
|
||||
sorted_values = sorted(dist.keys())
|
||||
cumulative_prob = 0
|
||||
|
||||
for val in sorted_values:
|
||||
cumulative_prob += dist[val]
|
||||
if cumulative_prob >= percentile:
|
||||
return val
|
||||
cumdist = np.cumsum(dist)
|
||||
|
||||
# Return the last value if percentile is higher than all probabilities
|
||||
return sorted_values[-1]
|
||||
return np.argmax(cumdist > percentile)
|
||||
|
||||
# parameters for logic.
|
||||
# perc_return is, per difficulty, the percentages of total score it returns (it averages out the values)
|
||||
@@ -188,8 +213,8 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu
|
||||
perc_return = [[0], [0.1, 0.5], [0.3, 0.7], [0.55, 0.85], [0.85, 0.95]][diff]
|
||||
diff_divide = [0, 9, 7, 3, 2][diff]
|
||||
|
||||
# calculate total distribution
|
||||
total_dist = {0: 1}
|
||||
# calculate total distribution, start in pmf-form
|
||||
total_dist = [1]
|
||||
for j, category in enumerate(categories):
|
||||
if num_dice == 0 or num_rolls == 0:
|
||||
dist = {0: 100000}
|
||||
@@ -208,12 +233,12 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu
|
||||
|
||||
total_dist = add_distributions(total_dist, dist)
|
||||
|
||||
# save result into the cache, then return it
|
||||
# note, total_dist is in pmf-form
|
||||
outcome = sum([percentile_distribution(total_dist, perc) for perc in perc_return]) / len(perc_return)
|
||||
# save result into the cache, then return it
|
||||
yachtdice_cache[tup] = max(5, math.floor(outcome)) # at least 5.
|
||||
|
||||
return yachtdice_cache[tup]
|
||||
|
||||
return yachtdice_cache[tup]
|
||||
|
||||
|
||||
def dice_simulation(state, player, options):
|
||||
@@ -243,7 +268,6 @@ def dice_simulation(state, player, options):
|
||||
return state.prog_items[player]["maximum_achievable_score"]
|
||||
|
||||
|
||||
|
||||
def set_yacht_rules(world: MultiWorld, player: int, options):
|
||||
"""
|
||||
Sets rules on entrances and advancements that are always applied
|
||||
@@ -262,4 +286,3 @@ def set_yacht_completion_rules(world: MultiWorld, player: int):
|
||||
Sets rules on completion condition
|
||||
"""
|
||||
world.completion_condition[player] = lambda state: state.has("Victory", player)
|
||||
|
||||
Reference in New Issue
Block a user