mabby
A multi-armed bandit (MAB) simulation library.
mabby is a library for simulating multi-armed bandits (MABs), a resource-allocation problem and framework in reinforcement learning. It allows users to quickly yet flexibly define and run bandit simulations, with the ability to:
- choose from a wide range of classic bandit algorithms to use
- configure environments with custom arm spaces and rewards distributions
- collect and visualize simulation metrics like regret and optimality
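To give a feel for the workflow mabby automates, here is a plain-Python sketch of one bandit simulation loop (this is not mabby code; all names are hypothetical stand-ins for the library's Bandit/Agent machinery):

```python
import random

random.seed(0)

# Hypothetical stand-ins: hidden Bernoulli means for three arms, and an
# epsilon-greedy agent tracking play counts (Ns) and value estimates (Qs).
means = [0.2, 0.5, 0.8]
k, steps, eps = len(means), 1000, 0.1
Ns, Qs = [0] * k, [0.0] * k
regret = 0.0

for _ in range(steps):
    if random.random() < eps:                          # explore
        choice = random.randrange(k)
    else:                                              # exploit best estimate
        choice = max(range(k), key=lambda i: Qs[i])
    reward = 1.0 if random.random() < means[choice] else 0.0
    Ns[choice] += 1
    Qs[choice] += (reward - Qs[choice]) / Ns[choice]   # incremental mean
    regret += max(means) - means[choice]               # expected-reward gap
```

mabby packages each of these pieces (arms, agents, strategies, metrics) behind its own classes, documented below.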
Agent(strategy, name=None)
Agent in a multi-armed bandit simulation.
An agent represents an autonomous entity in a bandit simulation. It wraps around a specified strategy and provides an interface for each part of the decision-making process, including making a choice and then updating internal parameter estimates based on the observed rewards from that choice.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
strategy | Strategy | The bandit strategy to use. | required |
name | str \| None | An optional name for the agent. | None |
Source code in mabby/agent.py
Ns: NDArray[np.uint32]
property
The number of times the agent has played each arm.
The play counts are only available after the agent has been primed.
Returns:
Type | Description |
---|---|
NDArray[np.uint32] | An array of the play counts of each arm. |
Raises:
Type | Description |
---|---|
AgentUsageError | If the agent has not been primed. |
Qs: NDArray[np.float64]
property
The agent's current estimated action values (Q-values).
The action values are only available after the agent has been primed.
Returns:
Type | Description |
---|---|
NDArray[np.float64] | An array of the action values of each arm. |
Raises:
Type | Description |
---|---|
AgentUsageError | If the agent has not been primed. |
__repr__()
Returns the agent's string representation.
Uses the agent's name if set. Otherwise, the string representation of the agent's strategy is used by default.
Source code in mabby/agent.py
choose()
Returns the agent's next choice based on its strategy.
This method can only be called on a primed agent.
Returns:
Type | Description |
---|---|
int | The index of the arm chosen by the agent. |
Raises:
Type | Description |
---|---|
AgentUsageError | If the agent has not been primed. |
Source code in mabby/agent.py
prime(k, steps, rng)
Primes the agent before running a trial.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
k | int | The number of bandit arms for the agent to choose from. | required |
steps | int | The number of steps the simulation will be run for. | required |
rng | Generator | A random number generator. | required |
Source code in mabby/agent.py
update(reward)
Updates the agent's internal parameter estimates.
This method can only be called if the agent has previously made a choice, and an update based on that choice has not already been made.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reward | float | The observed reward from the agent's most recent choice. | required |
Raises:
Type | Description |
---|---|
AgentUsageError | If the agent has not previously made a choice. |
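For sample-average strategies, the update step typically folds the new reward into a running mean incrementally, so no reward history needs to be stored. This is a generic sketch of that rule, not necessarily what every mabby strategy does:

```python
def update_estimate(q: float, n: int, reward: float) -> tuple[float, int]:
    """Fold one observed reward into a running sample-average estimate."""
    n += 1
    q += (reward - q) / n  # incremental mean: Q_n = Q_{n-1} + (r - Q_{n-1})/n
    return q, n

q, n = 0.0, 0
for r in [1.0, 0.0, 1.0, 1.0]:
    q, n = update_estimate(q, n, r)
# q now holds the sample mean of the rewards seen so far
```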
Source code in mabby/agent.py
Arm(**kwargs)
Bases: ABC, EnforceOverrides
Base class for a bandit arm implementing a reward distribution.
An arm represents one of the decision choices available to the agent in a bandit problem. It has a hidden reward distribution and can be played by the agent to generate observable rewards.
Source code in mabby/arms.py
mean: float
property
abstractmethod
The mean reward of the arm.
Returns:
Type | Description |
---|---|
float | The computed mean of the arm's reward distribution. |
__repr__()
abstractmethod
Returns the string representation of the arm.
Source code in mabby/arms.py
bandit(rng=None, seed=None, **kwargs)
classmethod
Creates a bandit with arms of the same reward distribution type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rng | Generator \| None | A random number generator. | None |
seed | int \| None | A seed for random number generation if rng is not provided. | None |
**kwargs | list[float] | A dictionary where keys are arm parameter names and values are lists of parameter values for each arm. | {} |
Returns:
Type | Description |
---|---|
Bandit | A bandit with the specified arms. |
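The documented call shape (parameter names mapped to per-arm value lists) suggests a kwargs-zipping pattern like the plain-Python sketch below; make_arms and FakeArm are hypothetical stand-ins for illustration, not mabby API:

```python
def make_arms(arm_cls, **kwargs):
    """Zip per-arm parameter lists into one constructor call per arm.

    E.g. make_arms(FakeArm, p=[0.1, 0.9]) builds two arms.
    """
    lengths = {len(v) for v in kwargs.values()}
    if len(lengths) != 1:
        raise ValueError("all parameter lists must have the same length")
    names = list(kwargs)
    return [arm_cls(**dict(zip(names, values)))
            for values in zip(*kwargs.values())]

class FakeArm:
    """Hypothetical stand-in for an Arm subclass."""
    def __init__(self, p):
        self.p = p

arms = make_arms(FakeArm, p=[0.1, 0.9])  # two arms, one per listed value
```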
Source code in mabby/arms.py
play(rng)
abstractmethod
Plays the arm and samples a reward.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rng | Generator | A random number generator. | required |
Returns:
Type | Description |
---|---|
float | The sampled reward from the arm's reward distribution. |
Source code in mabby/arms.py
Bandit(arms, rng=None, seed=None)
Multi-armed bandit with one or more arms.
This class wraps around a list of arms, each of which has a reward distribution. It provides an interface for interacting with the arms, such as playing a specific arm, querying for the optimal arm, and computing regret from a given choice.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arms | list[Arm] | A list of arms for the bandit. | required |
rng | Generator \| None | A random number generator. | None |
seed | int \| None | A seed for random number generation if rng is not provided. | None |
Source code in mabby/bandit.py
means: list[float]
property
The mean rewards of the bandit's arms.
__getitem__(i)
Returns an arm by index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
i | int | The index of the arm to get. | required |
Returns:
Type | Description |
---|---|
Arm | The arm at the given index. |
Source code in mabby/bandit.py
__iter__()
Returns an iterator over the bandit's arms.
Source code in mabby/bandit.py
__len__()
Returns the number of arms.
Source code in mabby/bandit.py
__repr__()
Returns a string representation of the bandit.
Source code in mabby/bandit.py
best_arm()
Returns the index of the optimal arm.
The optimal arm is the arm with the greatest expected reward. If there are multiple arms with equal expected rewards, a random one is chosen.
Returns:
Type | Description |
---|---|
int | The index of the optimal arm. |
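The random tie-breaking described above can be pictured with this plain-Python sketch (not mabby's implementation):

```python
import random

def best_arm(means, rng=random):
    """Index of the highest-mean arm, breaking ties uniformly at random."""
    best = max(means)
    candidates = [i for i, m in enumerate(means) if m == best]
    return rng.choice(candidates)

best_arm([0.1, 0.9, 0.9])  # index 1 or 2, chosen at random
```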
Source code in mabby/bandit.py
is_opt(choice)
Returns the optimality of a given choice.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
choice | int | The index of the chosen arm. | required |
Returns:
Type | Description |
---|---|
bool | Whether the chosen arm is the optimal arm. |
Source code in mabby/bandit.py
play(i)
Plays an arm by index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
i | int | The index of the arm to play. | required |
Returns:
Type | Description |
---|---|
float | The reward from playing the arm. |
Source code in mabby/bandit.py
regret(choice)
Returns the regret from a given choice.
The regret is computed as the difference between the expected reward from the optimal arm and the expected reward from the chosen arm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
choice | int | The index of the chosen arm. | required |
Returns:
Type | Description |
---|---|
float | The computed regret value. |
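In code form, the regret definition above is simply the expected-reward gap (a plain-Python sketch over a list of arm means):

```python
def regret(means, choice):
    """Gap between the optimal arm's expected reward and the chosen arm's."""
    return max(means) - means[choice]

means = [0.2, 0.5, 0.8]
gaps = [regret(means, i) for i in range(len(means))]
# the optimal arm (index 2) incurs zero regret; the others incur the gap
```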
Source code in mabby/bandit.py
BernoulliArm(p)
Bases: Arm
Bandit arm with a Bernoulli reward distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | float | Parameter of the Bernoulli distribution. | required |
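Playing a Bernoulli arm amounts to a weighted coin flip yielding reward 1 with probability p, else 0. A plain-Python sketch of that sampling (not mabby's implementation):

```python
import random

def play_bernoulli(p, rng):
    """Sample a Bernoulli reward: 1.0 with probability p, otherwise 0.0."""
    return 1.0 if rng.random() < p else 0.0

rng = random.Random(0)
rewards = [play_bernoulli(0.7, rng) for _ in range(10_000)]
# the empirical mean of the rewards approaches p
```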
Source code in mabby/arms.py
GaussianArm(loc, scale)
Bases: Arm
Bandit arm with a Gaussian reward distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loc | float | Mean ("center") of the Gaussian distribution. | required |
scale | float | Standard deviation of the Gaussian distribution. | required |
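Likewise, playing a Gaussian arm samples from a normal distribution with the given mean and standard deviation. A plain-Python sketch (not mabby's implementation):

```python
import random

def play_gaussian(loc, scale, rng):
    """Sample a reward from a normal distribution with mean loc, stddev scale."""
    return rng.gauss(loc, scale)

rng = random.Random(42)
rewards = [play_gaussian(1.0, 0.5, rng) for _ in range(20_000)]
# the empirical mean of the rewards approaches loc
```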
Source code in mabby/arms.py
Metric(label, base=None, transform=None)
Bases: Enum
Enum for metrics that simulations can track.
Metrics can be derived from other metrics by specifying a base metric and a transform function. This is useful for defining cumulative versions of an existing metric, where the transformed values can be computed "lazily" instead of being redundantly stored.
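The base/transform mechanism can be pictured with this hypothetical plain-Python sketch, where a derived cumulative metric is computed on demand from stored per-step values rather than stored itself:

```python
from itertools import accumulate

# Hypothetical stand-in: only the base metric's per-step values are stored.
base_values = [0.3, 0.0, 0.3, 0.1]  # e.g. per-step regret

def cumulative(values):
    """Transform for a derived 'cumulative' metric: running totals."""
    return list(accumulate(values))

cum_regret = cumulative(base_values)  # computed lazily at read time
```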
Parameters:
Name | Type | Description | Default |
---|---|---|---|
label | str | Verbose name of the metric (title case). | required |
base | str \| None | Name of the base metric. | None |
transform | Callable[[NDArray[np.float64]], NDArray[np.float64]] \| None | Transformation function from the base metric. | None |
Source code in mabby/stats.py
base: Metric
property
The base metric that the metric is transformed from.
If the metric is already a base metric, the metric itself is returned.
__repr__()
Returns the verbose name of the metric.
Source code in mabby/stats.py
is_base()
Returns whether the metric is a base metric.
Returns:
Type | Description |
---|---|
bool | Whether the metric is a base metric. |
Source code in mabby/stats.py
map_to_base(metrics)
classmethod
Traces all metrics back to their base metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metrics | Iterable[Metric] | A collection of metrics. | required |
Returns:
Type | Description |
---|---|
Iterable[Metric] | A set containing the base metrics of all the input metrics. |
Source code in mabby/stats.py
transform(values)
Transforms values from the base metric.
If the metric is already a base metric, the input values are returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | NDArray[np.float64] | An array of input values for the base metric. | required |
Returns:
Type | Description |
---|---|
NDArray[np.float64] | An array of transformed values for the metric. |
Source code in mabby/stats.py
Simulation(bandit, agents=None, strategies=None, names=None, rng=None, seed=None)
Simulation of a multi-armed bandit problem.
A simulation consists of multiple trials of one or more bandit strategies run on a configured multi-armed bandit.
One of agents or strategies must be supplied. If agents is supplied, strategies and names are ignored. Otherwise, an agent is created for each strategy and given a name from names if available.
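The precedence just described can be sketched in plain Python. Here resolve_agents is a hypothetical stand-in, and (strategy, name) tuples stand in for the Agent objects the real constructor would build:

```python
def resolve_agents(agents=None, strategies=None, names=None):
    """Mirror the documented precedence: agents win outright; otherwise
    each strategy is paired with a name from names when one is available."""
    if agents is not None:
        return list(agents)  # strategies and names are ignored
    if strategies is None:
        raise ValueError("one of agents or strategies must be supplied")
    names = list(names or [])
    return [
        (strategy, names[i] if i < len(names) else None)
        for i, strategy in enumerate(strategies)
    ]

pairs = resolve_agents(strategies=["eps", "ucb"], names=["agent-1"])
# the second strategy gets no name, so it falls back to None
```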
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bandit | Bandit | A configured multi-armed bandit to simulate on. | required |
agents | Iterable[Agent] \| None | A list of agents to simulate. | None |
strategies | Iterable[Strategy] \| None | A list of strategies to simulate. | None |
names | Iterable[str] \| None | A list of names for agents. | None |
rng | Generator \| None | A random number generator. | None |
seed | int \| None | A seed for random number generation if rng is not provided. | None |
Raises:
Type | Description |
---|---|
SimulationUsageError | If neither agents nor strategies is supplied. |
Source code in mabby/simulation.py
run(trials, steps, metrics=None)
Runs a simulation.
In a simulation run, each agent or strategy is run for the specified number of trials, and each trial is run for the given number of steps.
If metrics is not specified, all available metrics are tracked by default.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trials | int | The number of trials in the simulation. | required |
steps | int | The number of steps in a trial. | required |
metrics | Iterable[Metric] \| None | A list of metrics to collect. | None |
Returns:
Type | Description |
---|---|
SimulationStats | A SimulationStats object containing the statistics collected during the simulation. |
Source code in mabby/simulation.py
SimulationStats(simulation)
Statistics for a multi-armed bandit simulation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
simulation | Simulation | The simulation to track. | required |
Source code in mabby/stats.py
__contains__(agent)
Returns whether an agent's statistics are present.
Returns:
Type | Description |
---|---|
bool | Whether the agent's statistics are present. |
Source code in mabby/stats.py
__getitem__(agent)
Gets statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent | Agent | The agent to get the statistics of. | required |
Returns:
Type | Description |
---|---|
AgentStats | The statistics of the agent. |
Source code in mabby/stats.py
__setitem__(agent, agent_stats)
Sets the statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent | Agent | The agent to set the statistics of. | required |
agent_stats | AgentStats | The agent statistics to set. | required |
Source code in mabby/stats.py
add(agent_stats)
Adds statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_stats | AgentStats | The agent statistics to add. | required |
Source code in mabby/stats.py
plot(metric)
Generates a plot for a simulation metric.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The metric to plot. | required |
Source code in mabby/stats.py
plot_optimality()
Generates a plot for the optimality metric.
Source code in mabby/stats.py
plot_regret(cumulative=True)
Generates a plot for the regret or cumulative regret metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cumulative | bool | Whether to use the cumulative regret. | True |
Source code in mabby/stats.py
plot_rewards(cumulative=True)
Generates a plot for the rewards or cumulative rewards metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cumulative | bool | Whether to use the cumulative rewards. | True |
Source code in mabby/stats.py
Strategy()
Bases: ABC, EnforceOverrides
Base class for a bandit strategy.
A strategy provides the computational logic for choosing which bandit arms to play and updating parameter estimates.
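A minimal strategy with the same prime/choose/update surface might look like this epsilon-greedy sketch (plain Python, written for illustration rather than taken from mabby's source):

```python
import random

class EpsilonGreedy:
    """Sketch of a strategy: explore with probability eps, else exploit."""

    def __init__(self, eps: float):
        self.eps = eps

    def prime(self, k: int, steps: int) -> None:
        # Reset play counts and action-value estimates for k arms.
        self.Ns = [0] * k
        self.Qs = [0.0] * k

    def choose(self, rng: random.Random) -> int:
        if rng.random() < self.eps:
            return rng.randrange(len(self.Qs))       # explore
        return max(range(len(self.Qs)), key=lambda i: self.Qs[i])  # exploit

    def update(self, choice: int, reward: float, rng=None) -> None:
        self.Ns[choice] += 1
        self.Qs[choice] += (reward - self.Qs[choice]) / self.Ns[choice]

strategy = EpsilonGreedy(eps=0.1)
strategy.prime(k=3, steps=100)
rng = random.Random(1)
arm = strategy.choose(rng)
strategy.update(arm, reward=1.0)
```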
Source code in mabby/strategies/strategy.py
Ns: NDArray[np.uint32]
property
abstractmethod
The number of times each arm has been played.
Qs: NDArray[np.float64]
property
abstractmethod
The current estimated action values for each arm.
__repr__()
abstractmethod
Returns a string representation of the strategy.
Source code in mabby/strategies/strategy.py
agent(**kwargs)
Creates an agent following the strategy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs | str | Parameters for initializing the agent (see Agent). | {} |
Returns:
Type | Description |
---|---|
Agent | The created agent with the strategy. |
Source code in mabby/strategies/strategy.py
choose(rng)
abstractmethod
Returns the next arm to play.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rng | Generator | A random number generator. | required |
Returns:
Type | Description |
---|---|
int | The index of the arm to play. |
Source code in mabby/strategies/strategy.py
prime(k, steps)
abstractmethod
Primes the strategy before running a trial.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
k | int | The number of bandit arms to choose from. | required |
steps | int | The number of steps the simulation will be run for. | required |
Source code in mabby/strategies/strategy.py
update(choice, reward, rng=None)
abstractmethod
Updates internal parameter estimates based on reward observation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
choice | int | The most recent choice made. | required |
reward | float | The observed reward from the agent's most recent choice. | required |
rng | Generator \| None | A random number generator. | None |
Source code in mabby/strategies/strategy.py