mabby

A multi-armed bandit (MAB) simulation library.

mabby is a library for simulating multi-armed bandits (MABs), a resource-allocation problem and framework in reinforcement learning. It allows users to quickly yet flexibly define and run bandit simulations, with the ability to:

  • choose from a wide range of classic bandit algorithms
  • configure environments with custom arm spaces and reward distributions
  • collect and visualize simulation metrics like regret and optimality

Agent(strategy, name=None)

Agent in a multi-armed bandit simulation.

An agent represents an autonomous entity in a bandit simulation. It wraps a specified strategy and provides an interface for each part of the decision-making process: making a choice, then updating internal parameter estimates based on the reward observed from that choice.

Parameters:

  • strategy (Strategy, required): The bandit strategy to use.
  • name (str | None, default None): An optional name for the agent.

Source code in mabby/agent.py
def __init__(self, strategy: Strategy, name: str | None = None):
    """Initializes an agent with a given strategy.

    Args:
        strategy: The bandit strategy to use.
        name: An optional name for the agent.
    """
    self.strategy: Strategy = strategy  #: The bandit strategy to use
    self._name = name
    self._primed = False
    self._choice: int | None = None

Ns: NDArray[np.uint32] property

The number of times the agent has played each arm.

The play counts are only available after the agent has been primed.

Returns:

  • NDArray[np.uint32]: An array of the play counts of each arm.

Raises:

  • AgentUsageError: If the agent has not been primed.

Qs: NDArray[np.float64] property

The agent's current estimated action values (Q-values).

The action values are only available after the agent has been primed.

Returns:

  • NDArray[np.float64]: An array of the action values of each arm.

Raises:

  • AgentUsageError: If the agent has not been primed.

__repr__()

Returns the agent's string representation.

Uses the agent's name if set. Otherwise, the string representation of the agent's strategy is used by default.

Source code in mabby/agent.py
def __repr__(self) -> str:
    """Returns the agent's string representation.

    Uses the agent's name if set. Otherwise, the string representation of the
    agent's strategy is used by default.
    """
    if self._name is None:
        return str(self.strategy)
    return self._name

choose()

Returns the agent's next choice based on its strategy.

This method can only be called on a primed agent.

Returns:

  • int: The index of the arm chosen by the agent.

Raises:

  • AgentUsageError: If the agent has not been primed.

Source code in mabby/agent.py
def choose(self) -> int:
    """Returns the agent's next choice based on its strategy.

    This method can only be called on a primed agent.

    Returns:
        The index of the arm chosen by the agent.

    Raises:
        AgentUsageError: If the agent has not been primed.
    """
    if not self._primed:
        raise AgentUsageError("choose() can only be called on a primed agent")
    self._choice = self.strategy.choose(self._rng)
    return self._choice

prime(k, steps, rng)

Primes the agent before running a trial.

Parameters:

  • k (int, required): The number of bandit arms for the agent to choose from.
  • steps (int, required): The number of steps the simulation will be run for.
  • rng (Generator, required): A random number generator.

Source code in mabby/agent.py
def prime(self, k: int, steps: int, rng: Generator) -> None:
    """Primes the agent before running a trial.

    Args:
        k: The number of bandit arms for the agent to choose from.
        steps: The number of steps the simulation will be run for.
        rng: A random number generator.
    """
    self._primed = True
    self._choice = None
    self._rng = rng
    self.strategy.prime(k, steps)

update(reward)

Updates the agent's internal parameter estimates.

This method can only be called if the agent has previously made a choice, and an update based on that choice has not already been made.

Parameters:

  • reward (float, required): The observed reward from the agent's most recent choice.

Raises:

  • AgentUsageError: If the agent has not made a choice since it was primed or since its last update.

Source code in mabby/agent.py
def update(self, reward: float) -> None:
    """Updates the agent's internal parameter estimates.

    This method can only be called if the agent has previously made a choice, and
    an update based on that choice has not already been made.

    Args:
        reward: The observed reward from the agent's most recent choice.

    Raises:
        AgentUsageError: If the agent has not made a choice since it was primed or
            since its last update.
    """
    if self._choice is None:
        raise AgentUsageError("update() can only be called after choose()")
    self.strategy.update(self._choice, reward, self._rng)
    self._choice = None
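
The usage rules for choose() and update() amount to a small state machine: prime() enables choose(), and each choice can feed exactly one update(). The stdlib-only sketch below mirrors that protocol; it is not mabby's implementation. `MiniAgent` and `CountingStrategy` are hypothetical, `random.Random` stands in for NumPy's `Generator`, and `RuntimeError` stands in for `AgentUsageError`.

```python
import random


class CountingStrategy:
    """Hypothetical strategy: always picks arm 0 and counts updates."""

    def prime(self, k: int, steps: int) -> None:
        self.updates = 0

    def choose(self, rng: random.Random) -> int:
        return 0

    def update(self, choice: int, reward: float, rng: random.Random) -> None:
        self.updates += 1


class MiniAgent:
    """Sketch of the prime -> choose -> update protocol described above."""

    def __init__(self, strategy):
        self.strategy = strategy
        self._primed = False
        self._choice = None

    def prime(self, k: int, steps: int, rng: random.Random) -> None:
        self._primed = True
        self._choice = None  # discard any choice left over from a previous trial
        self._rng = rng
        self.strategy.prime(k, steps)

    def choose(self) -> int:
        if not self._primed:
            # RuntimeError stands in for mabby's AgentUsageError
            raise RuntimeError("choose() can only be called on a primed agent")
        self._choice = self.strategy.choose(self._rng)
        return self._choice

    def update(self, reward: float) -> None:
        if self._choice is None:
            raise RuntimeError("update() can only be called after choose()")
        self.strategy.update(self._choice, reward, self._rng)
        self._choice = None  # each choice may be used for exactly one update


agent = MiniAgent(CountingStrategy())
agent.prime(k=2, steps=10, rng=random.Random(0))
agent.choose()
agent.update(1.0)
try:
    agent.update(1.0)  # a second update for the same choice is rejected
except RuntimeError as err:
    print(err)  # update() can only be called after choose()
```

Resetting the pending choice in both prime() and update() is what makes a stale choice from a previous trial, or a duplicated update, fail loudly instead of silently skewing the estimates.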

Arm(**kwargs)

Bases: ABC, EnforceOverrides

Base class for a bandit arm implementing a reward distribution.

An arm represents one of the decision choices available to the agent in a bandit problem. It has a hidden reward distribution and can be played by the agent to generate observable rewards.

Source code in mabby/arms.py
@abstractmethod
def __init__(self, **kwargs: float):
    """Initializes an arm."""

mean: float property abstractmethod

The mean reward of the arm.

Returns:

  • float: The computed mean of the arm's reward distribution.

__repr__() abstractmethod

Returns the string representation of the arm.

Source code in mabby/arms.py
@abstractmethod
def __repr__(self) -> str:
    """Returns the string representation of the arm."""

bandit(rng=None, seed=None, **kwargs) classmethod

Creates a bandit with arms of the same reward distribution type.

Parameters:

  • rng (Generator | None, default None): A random number generator.
  • seed (int | None, default None): A seed for random number generation if rng is not provided.
  • **kwargs (list[float], optional): A dictionary where keys are arm parameter names and values are lists of parameter values for each arm.

Returns:

  • Bandit: A bandit with the specified arms.

Raises:

  • ValueError: If no arm parameters are supplied, or if any parameter list is empty.

Source code in mabby/arms.py
@classmethod
def bandit(
    cls,
    rng: Generator | None = None,
    seed: int | None = None,
    **kwargs: list[float],
) -> Bandit:
    """Creates a bandit with arms of the same reward distribution type.

    Args:
        rng: A random number generator.
        seed: A seed for random number generation if ``rng`` is not provided.
        **kwargs: A dictionary where keys are arm parameter names and values are
            lists of parameter values for each arm.

    Returns:
        A bandit with the specified arms.
    """
    params_dicts = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
    if len(params_dicts) == 0:
        raise ValueError("insufficient parameters to create an arm")
    return Bandit([cls(**params) for params in params_dicts], rng, seed)
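
The first line of bandit() transposes the keyword lists into one parameter dict per arm. The sketch below isolates that same expression in a hypothetical helper, `params_dicts`, to show what it produces. Because `zip()` stops at the shortest input, lists of unequal length silently drop arms rather than raising; nothing in the code shown here checks for that.

```python
def params_dicts(**kwargs: list[float]) -> list[dict[str, float]]:
    """Transposes keyword lists into one parameter dict per arm (same expression as bandit())."""
    return [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]


print(params_dicts(loc=[0.0, 1.0], scale=[1.0, 0.5]))
# [{'loc': 0.0, 'scale': 1.0}, {'loc': 1.0, 'scale': 0.5}]

# zip() stops at the shortest list, so a length mismatch silently drops arms.
print(params_dicts(loc=[0.0, 1.0, 2.0], scale=[1.0]))
# [{'loc': 0.0, 'scale': 1.0}]

# No parameters at all yields an empty list, which bandit() rejects with ValueError.
print(params_dicts())
# []
```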

play(rng) abstractmethod

Plays the arm and samples a reward.

Parameters:

  • rng (Generator, required): A random number generator.

Returns:

  • float: The sampled reward from the arm's reward distribution.

Source code in mabby/arms.py
@abstractmethod
def play(self, rng: Generator) -> float:
    """Plays the arm and samples a reward.

    Args:
        rng: A random number generator.

    Returns:
        The sampled reward from the arm's reward distribution.
    """

Bandit(arms, rng=None, seed=None)

Multi-armed bandit with one or more arms.

This class wraps around a list of arms, each of which has a reward distribution. It provides an interface for interacting with the arms, such as playing a specific arm, querying for the optimal arm, and computing regret from a given choice.

Parameters:

  • arms (list[Arm], required): A list of arms for the bandit.
  • rng (Generator | None, default None): A random number generator.
  • seed (int | None, default None): A seed for random number generation if rng is not provided.

Source code in mabby/bandit.py
def __init__(
    self, arms: list[Arm], rng: Generator | None = None, seed: int | None = None
):
    """Initializes a bandit with a given set of arms.

    Args:
        arms: A list of arms for the bandit.
        rng: A random number generator.
        seed: A seed for random number generation if ``rng`` is not provided.
    """
    self._arms = arms
    self._rng = rng if rng else np.random.default_rng(seed)

means: list[float] property

The means of the arms.

Returns:

  • list[float]: A list of the means of each arm.

__getitem__(i)

Returns an arm by index.

Parameters:

  • i (int, required): The index of the arm to get.

Returns:

  • Arm: The arm at the given index.

Source code in mabby/bandit.py
def __getitem__(self, i: int) -> Arm:
    """Returns an arm by index.

    Args:
        i: The index of the arm to get.

    Returns:
        The arm at the given index.
    """
    return self._arms[i]

__iter__()

Returns an iterator over the bandit's arms.

Source code in mabby/bandit.py
def __iter__(self) -> Iterable[Arm]:
    """Returns an iterator over the bandit's arms."""
    return iter(self._arms)

__len__()

Returns the number of arms.

Source code in mabby/bandit.py
def __len__(self) -> int:
    """Returns the number of arms."""
    return len(self._arms)

__repr__()

Returns a string representation of the bandit.

Source code in mabby/bandit.py
def __repr__(self) -> str:
    """Returns a string representation of the bandit."""
    return repr(self._arms)

best_arm()

Returns the index of the optimal arm.

The optimal arm is the arm with the greatest expected reward. If there are multiple arms with equal expected rewards, a random one is chosen.

Returns:

  • int: The index of the optimal arm.

Source code in mabby/bandit.py
def best_arm(self) -> int:
    """Returns the index of the optimal arm.

    The optimal arm is the arm with the greatest expected reward. If there are
    multiple arms with equal expected rewards, a random one is chosen.

    Returns:
        The index of the optimal arm.
    """
    return random_argmax(self.means, rng=self._rng)
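
best_arm() relies on a `random_argmax` helper whose implementation is not shown on this page. A hypothetical stdlib version with the documented behavior (a uniformly random choice among all indices attaining the maximum) could look like this; a plain `max`/`index` lookup would instead always return the first tied arm, biasing results toward lower indices.

```python
import random
from collections import Counter


def random_argmax(values: list[float], rng: random.Random) -> int:
    """Returns an index of the maximum value, breaking ties uniformly at random."""
    best = max(values)
    candidates = [i for i, v in enumerate(values) if v == best]
    return rng.choice(candidates)


rng = random.Random(0)
means = [0.2, 0.7, 0.7, 0.1]
counts = Counter(random_argmax(means, rng) for _ in range(1_000))
print(sorted(counts))  # [1, 2]: only the two tied arms are ever returned
```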

is_opt(choice)

Returns the optimality of a given choice.

Parameters:

  • choice (int, required): The index of the chosen arm.

Returns:

  • bool: True if the arm has the greatest expected reward, False otherwise.

Source code in mabby/bandit.py
def is_opt(self, choice: int) -> bool:
    """Returns the optimality of a given choice.

    Args:
        choice: The index of the chosen arm.

    Returns:
        ``True`` if the arm has the greatest expected reward, ``False`` otherwise.
    """
    return np.max(self.means) == self._arms[choice].mean

play(i)

Plays an arm by index.

Parameters:

  • i (int, required): The index of the arm to play.

Returns:

  • float: The reward from playing the arm.

Source code in mabby/bandit.py
def play(self, i: int) -> float:
    """Plays an arm by index.

    Args:
        i: The index of the arm to play.

    Returns:
        The reward from playing the arm.
    """
    return self[i].play(self._rng)

regret(choice)

Returns the regret from a given choice.

The regret is computed as the difference between the expected reward from the optimal arm and the expected reward from the chosen arm.

Parameters:

  • choice (int, required): The index of the chosen arm.

Returns:

  • float: The computed regret value.

Source code in mabby/bandit.py
def regret(self, choice: int) -> float:
    """Returns the regret from a given choice.

    The regret is computed as the difference between the expected reward from the
    optimal arm and the expected reward from the chosen arm.

    Args:
        choice: The index of the chosen arm.

    Returns:
        The computed regret value.
    """
    return np.max(self.means) - self._arms[choice].mean
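
Concretely, if the best expected reward is max(means), choosing arm a costs max(means) minus means[a]; with arm means [0.25, 0.75, 0.5], the three choices have regret 0.5, 0, and 0.25. The sketch below mirrors regret() and is_opt() as plain functions over a hypothetical list of means rather than a Bandit object.

```python
def regret(means: list[float], choice: int) -> float:
    """Expected reward of the best arm minus expected reward of the chosen arm."""
    return max(means) - means[choice]


def is_opt(means: list[float], choice: int) -> bool:
    """True if the chosen arm attains the greatest expected reward."""
    return means[choice] == max(means)


means = [0.25, 0.75, 0.5]
print([regret(means, i) for i in range(len(means))])  # [0.5, 0.0, 0.25]
print([is_opt(means, i) for i in range(len(means))])  # [False, True, False]

# Regret uses expected rewards, not sampled ones, so a lucky draw from a
# suboptimal arm still incurs regret.
choices = [0, 2, 1, 1]
print(sum(regret(means, c) for c in choices))  # 0.75 cumulative regret
```

When several arms tie for the maximum, is_opt() reports every one of them as optimal, whereas best_arm() returns only one of them.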

BernoulliArm(p)

Bases: Arm

Bandit arm with a Bernoulli reward distribution.

Parameters:

  • p (float, required): Parameter (success probability) of the Bernoulli distribution.

Raises:

  • ValueError: If p is outside the interval [0, 1].

Source code in mabby/arms.py
def __init__(self, p: float):
    """Initializes a Bernoulli arm.

    Args:
        p: Parameter of the Bernoulli distribution.
    """
    if p < 0 or p > 1:
        raise ValueError(
            f"float {str(p)} is not a valid probability for Bernoulli distribution"
        )

    self.p: float = p  #: Parameter of the Bernoulli distribution
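
Only the constructor of BernoulliArm appears here. For a Bernoulli(p) distribution, each play yields 1 with probability p and 0 otherwise, so the mean reward is p. The hypothetical `BernoulliSketch` below illustrates that with `random.Random`; mabby's own `play()` and `mean` are not shown on this page and may be implemented differently (for example with NumPy).

```python
import random


class BernoulliSketch:
    """Hypothetical stand-in for BernoulliArm using the stdlib RNG."""

    def __init__(self, p: float):
        # Same validation as BernoulliArm.__init__: p must lie in [0, 1].
        if p < 0 or p > 1:
            raise ValueError(f"float {p} is not a valid probability for Bernoulli distribution")
        self.p = p

    def play(self, rng: random.Random) -> float:
        # random() is uniform on [0, 1), so this returns 1.0 with probability p.
        return 1.0 if rng.random() < self.p else 0.0

    @property
    def mean(self) -> float:
        return self.p


arm = BernoulliSketch(0.3)
rng = random.Random(7)
rewards = [arm.play(rng) for _ in range(100_000)]
print(arm.mean, sum(rewards) / len(rewards))  # sample mean should be close to 0.3
```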

GaussianArm(loc, scale)

Bases: Arm

Bandit arm with a Gaussian reward distribution.

Parameters:

  • loc (float, required): Mean ("center") of the Gaussian distribution.
  • scale (float, required): Standard deviation of the Gaussian distribution.

Raises:

  • ValueError: If scale is negative.

Source code in mabby/arms.py
def __init__(self, loc: float, scale: float):
    """Initializes a Gaussian arm.

    Args:
        loc: Mean ("center") of the Gaussian distribution.
        scale: Standard deviation of the Gaussian distribution.
    """
    if scale < 0:
        raise ValueError(
            f"float {str(scale)} is not a valid scale for Gaussian distribution"
        )

    self.loc: float = loc  #: Mean ("center") of the Gaussian distribution
    self.scale: float = scale  #: Standard deviation of the Gaussian distribution

Metric(label, base=None, transform=None)

Bases: Enum

Enum for metrics that simulations can track.

Metrics can be derived from other metrics through specifying a base metric and a transform function. This is useful for things like defining cumulative versions of an existing metric, where the transformed values can be computed "lazily" instead of being redundantly stored.

Parameters:

  • label (str, required): Verbose name of the metric (title case).
  • base (str | None, default None): Name of the base metric.
  • transform (Callable[[NDArray[np.float64]], NDArray[np.float64]] | None, default None): Transformation function from the base metric. The metric is treated as derived only if both base and transform are given.

Source code in mabby/stats.py
def __init__(
    self,
    label: str,
    base: str | None = None,
    transform: Callable[[NDArray[np.float64]], NDArray[np.float64]] | None = None,
):
    """Initializes a metric.

    Metrics can be derived from other metrics through specifying a ``base`` metric
    and a ``transform`` function. This is useful for things like defining cumulative
    versions of an existing metric, where the transformed values can be computed
    "lazily" instead of being redundantly stored.

    Args:
        label: Verbose name of the metric (title case)
        base: Name of the base metric
        transform: Transformation function from the base metric
    """
    self.__class__.__MAPPING__[self._name_] = self
    self._label = label
    self._mapping: MetricMapping | None = (
        MetricMapping(base=self.__class__.__MAPPING__[base], transform=transform)
        if base and transform
        else None
    )

base: Metric property

The base metric that the metric is transformed from.

If the metric is already a base metric, the metric itself is returned.

__repr__()

Returns the verbose name of the metric.

Source code in mabby/stats.py
def __repr__(self) -> str:
    """Returns the verbose name of the metric."""
    return self._label

is_base()

Returns whether the metric is a base metric.

Returns:

  • bool: True if the metric is a base metric, False otherwise.

Source code in mabby/stats.py
def is_base(self) -> bool:
    """Returns whether the metric is a base metric.

    Returns:
        ``True`` if the metric is a base metric, ``False`` otherwise.
    """
    return self._mapping is None

map_to_base(metrics) classmethod

Traces all metrics back to their base metrics.

Parameters:

  • metrics (Iterable[Metric], required): A collection of metrics.

Returns:

  • Iterable[Metric]: A set containing the base metrics of all the input metrics.

Source code in mabby/stats.py
@classmethod
def map_to_base(cls, metrics: Iterable[Metric]) -> Iterable[Metric]:
    """Traces all metrics back to their base metrics.

    Args:
        metrics: A collection of metrics.

    Returns:
        A set containing the base metrics of all the input metrics.
    """
    return set(m.base for m in metrics)

transform(values)

Transforms values from the base metric.

If the metric is already a base metric, the input values are returned.

Parameters:

  • values (NDArray[np.float64], required): An array of input values for the base metric.

Returns:

  • NDArray[np.float64]: An array of transformed values for the metric.

Source code in mabby/stats.py
def transform(self, values: NDArray[np.float64]) -> NDArray[np.float64]:
    """Transforms values from the base metric.

    If the metric is already a base metric, the input values are returned.

    Args:
        values: An array of input values for the base metric.

    Returns:
        An array of transformed values for the metric.
    """
    if self._mapping is not None:
        return self._mapping.transform(values)
    return values
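
The base/transform mechanism can be sketched with a plain Enum whose values carry a label, the name of a base member, and a transform. This is a hypothetical reconstruction, not mabby's Metric: the `__MAPPING__` and `MetricMapping` machinery is replaced by a lazy name lookup, and the running-sum transform for the cumulative member is an assumption suggested by its name (the page references `Metric.REGRET` and `Metric.CUM_REGRET` but does not show their definitions). Only the base series needs to be stored; the cumulative series is derived on demand.

```python
from enum import Enum
from itertools import accumulate


def _cumsum(values: list[float]) -> list[float]:
    # Running sum; assumed here as the transform behind a "cumulative" metric.
    return list(accumulate(values))


class MetricSketch(Enum):
    # value = (label, base member name, transform)
    REGRET = ("Regret", None, None)
    CUM_REGRET = ("Cumulative Regret", "REGRET", _cumsum)

    def __init__(self, label, base_name, transform_fn):
        self.label = label
        self._base_name = base_name
        self._transform_fn = transform_fn

    def __repr__(self) -> str:
        return self.label

    def is_base(self) -> bool:
        # Mirrors the source: derived only if both a base and a transform are given.
        return self._base_name is None or self._transform_fn is None

    @property
    def base(self) -> "MetricSketch":
        # Resolved lazily, so the base member only needs to exist at lookup time.
        return self if self.is_base() else type(self)[self._base_name]

    def transform(self, values: list[float]) -> list[float]:
        return values if self.is_base() else self._transform_fn(values)

    @classmethod
    def map_to_base(cls, metrics) -> set:
        return {m.base for m in metrics}


stored = [1.0, 0.0, 2.0]  # only the base metric's values are stored
print(MetricSketch.CUM_REGRET.transform(stored))  # [1.0, 1.0, 3.0]
print(MetricSketch.map_to_base([MetricSketch.REGRET, MetricSketch.CUM_REGRET]))  # {Regret}
```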

Simulation(bandit, agents=None, strategies=None, names=None, rng=None, seed=None)

Simulation of a multi-armed bandit problem.

A simulation consists of multiple trials of one or more bandit strategies run on a configured multi-armed bandit.

One of agents or strategies must be supplied. If agents is supplied, strategies and names are ignored. Otherwise, an agent is created for each strategy and given a name from names if available.

Parameters:

  • bandit (Bandit, required): A configured multi-armed bandit to simulate on.
  • agents (Iterable[Agent] | None, default None): A list of agents to simulate.
  • strategies (Iterable[Strategy] | None, default None): A list of strategies to simulate.
  • names (Iterable[str] | None, default None): A list of names for agents.
  • rng (Generator | None, default None): A random number generator.
  • seed (int | None, default None): A seed for random number generation if rng is not provided.

Raises:

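  • SimulationUsageError: If neither agents nor strategies are supplied.
  • ValueError: If no agents result from the supplied agents or strategies (for example, an empty list), or if bandit has no arms.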

Source code in mabby/simulation.py
def __init__(
    self,
    bandit: Bandit,
    agents: Iterable[Agent] | None = None,
    strategies: Iterable[Strategy] | None = None,
    names: Iterable[str] | None = None,
    rng: Generator | None = None,
    seed: int | None = None,
):
    """Initializes a simulation.

    One of ``agents`` or ``strategies`` must be supplied. If ``agents`` is supplied,
    ``strategies`` and ``names`` are ignored. Otherwise, an ``agent`` is created for
    each ``strategy`` and given a name from ``names`` if available.

    Args:
        bandit: A configured multi-armed bandit to simulate on.
        agents: A list of agents to simulate.
        strategies: A list of strategies to simulate.
        names: A list of names for agents.
        rng: A random number generator.
        seed: A seed for random number generation if ``rng`` is not provided.

    Raises:
        SimulationUsageError: If neither ``agents`` nor ``strategies`` are supplied.
        ValueError: If no agents result from the inputs, or if ``bandit`` has no arms.
    """
    self.agents = self._create_agents(agents, strategies, names)
    if len(list(self.agents)) == 0:
        raise ValueError("no strategies or agents were supplied")
    self.bandit = bandit
    if len(self.bandit) == 0:
        raise ValueError("bandit cannot be empty")
    self._rng = rng if rng else np.random.default_rng(seed)
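
The helper `_create_agents` that applies these rules is not shown on this page. Below is a hypothetical sketch of the documented precedence, assuming that strategies without a matching entry in `names` simply get no name (rather than being dropped, which is what a bare `zip()` would do). `AgentStub` and the local `SimulationUsageError` are stand-ins, and strategies are plain strings for brevity.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


class SimulationUsageError(Exception):
    """Stand-in for mabby's SimulationUsageError."""


@dataclass
class AgentStub:
    """Hypothetical minimal agent: a strategy plus an optional name."""
    strategy: str
    name: Optional[str] = None


def create_agents(
    agents: Optional[Iterable[AgentStub]] = None,
    strategies: Optional[Iterable[str]] = None,
    names: Optional[Iterable[str]] = None,
) -> list[AgentStub]:
    if agents is not None:
        return list(agents)  # agents take precedence; strategies and names are ignored
    if strategies is None:
        raise SimulationUsageError("neither agents nor strategies were supplied")
    name_list = list(names) if names is not None else []
    return [
        # Strategies beyond the end of `names` get no name (None).
        AgentStub(s, name_list[i] if i < len(name_list) else None)
        for i, s in enumerate(strategies)
    ]


agents = create_agents(strategies=["strategy-a", "strategy-b", "strategy-c"], names=["first"])
print([a.name for a in agents])  # ['first', None, None]
```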

run(trials, steps, metrics=None)

Runs a simulation.

In a simulation run, each agent or strategy is run for the specified number of trials, and each trial is run for the given number of steps.

If metrics is not specified, all available metrics are tracked by default.

Parameters:

  • trials (int, required): The number of trials in the simulation.
  • steps (int, required): The number of steps in a trial.
  • metrics (Iterable[Metric] | None, default None): A list of metrics to collect.

Returns:

  • SimulationStats: A SimulationStats object with the results of the simulation.

Source code in mabby/simulation.py
def run(
    self, trials: int, steps: int, metrics: Iterable[Metric] | None = None
) -> SimulationStats:
    """Runs a simulation.

    In a simulation run, each agent or strategy is run for the specified number of
    trials, and each trial is run for the given number of steps.

    If ``metrics`` is not specified, all available metrics are tracked by default.

    Args:
        trials: The number of trials in the simulation.
        steps: The number of steps in a trial.
        metrics: A list of metrics to collect.

    Returns:
        A ``SimulationStats`` object with the results of the simulation.
    """
    sim_stats = SimulationStats(simulation=self)
    for agent in self.agents:
        agent_stats = self._run_trials_for_agent(agent, trials, steps, metrics)
        sim_stats.add(agent_stats)
    return sim_stats
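
run() hands each agent to `_run_trials_for_agent`, which is not shown here. The sketch below is a plausible, self-contained version of one agent's trials under stated assumptions: every trial re-primes the strategy, each step is choose, play, update, and per-step regret and optimality are averaged across trials. The epsilon-greedy strategy, the Bernoulli arms given as a list of success probabilities, and all names are hypothetical; mabby's actual aggregation may differ.

```python
import random
from typing import Optional


class EpsilonGreedy:
    """Hypothetical epsilon-greedy strategy with sample-average estimates."""

    def __init__(self, eps: float):
        self.eps = eps

    def prime(self, k: int, steps: int) -> None:
        self.Qs = [0.0] * k
        self.Ns = [0] * k

    def choose(self, rng: random.Random) -> int:
        if rng.random() < self.eps:
            return rng.randrange(len(self.Qs))  # explore uniformly
        best = max(self.Qs)
        return rng.choice([i for i, q in enumerate(self.Qs) if q == best])

    def update(self, choice: int, reward: float, rng: Optional[random.Random] = None) -> None:
        self.Ns[choice] += 1
        self.Qs[choice] += (reward - self.Qs[choice]) / self.Ns[choice]


def run_trials(strategy, ps: list[float], trials: int, steps: int, seed: int = 0):
    """Returns per-step regret and optimality, each averaged over all trials."""
    rng = random.Random(seed)
    best = max(ps)
    regret = [0.0] * steps
    optimality = [0.0] * steps
    for _ in range(trials):
        strategy.prime(len(ps), steps)  # fresh estimates for every trial
        for t in range(steps):
            choice = strategy.choose(rng)
            reward = 1.0 if rng.random() < ps[choice] else 0.0  # Bernoulli arm
            strategy.update(choice, reward, rng)
            regret[t] += (best - ps[choice]) / trials  # expected, not sampled, regret
            optimality[t] += (ps[choice] == best) / trials
    return regret, optimality


regret, optimality = run_trials(EpsilonGreedy(eps=0.1), ps=[0.2, 0.5, 0.8], trials=200, steps=500)
print(f"optimality: step 1 = {optimality[0]:.2f}, step 500 = {optimality[-1]:.2f}")
```

Cumulative regret then follows as a running sum of the per-step `regret` series, which is the kind of derived series that a base metric plus a transform can produce without storing it separately.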

SimulationStats(simulation)

Statistics for a multi-armed bandit simulation.

Parameters:

  • simulation (Simulation, required): The simulation to track.

Source code in mabby/stats.py
def __init__(self, simulation: Simulation):
    """Initializes simulation statistics.

    Args:
        simulation: The simulation to track.
    """
    self._simulation: Simulation = simulation
    self._stats_dict: dict[Agent, AgentStats] = {}

__contains__(agent)

Returns whether an agent's statistics are present.

Parameters:

  • agent (Agent, required): The agent to look up.

Returns:

  • bool: True if the agent's statistics are present, False otherwise.

Source code in mabby/stats.py
def __contains__(self, agent: Agent) -> bool:
    """Returns if an agent's statistics are present.

    Returns:
        ``True`` if an agent's statistics are present, ``False`` otherwise.
    """
    return agent in self._stats_dict

__getitem__(agent)

Gets statistics for an agent.

Parameters:

  • agent (Agent, required): The agent to get the statistics of.

Returns:

  • AgentStats: The statistics of the agent.

Source code in mabby/stats.py
def __getitem__(self, agent: Agent) -> AgentStats:
    """Gets statistics for an agent.

    Args:
        agent: The agent to get the statistics of.

    Returns:
        The statistics of the agent.
    """
    return self._stats_dict[agent]

__setitem__(agent, agent_stats)

Sets the statistics for an agent.

Parameters:

  • agent (Agent, required): The agent to set the statistics of.
  • agent_stats (AgentStats, required): The agent statistics to set.

Raises:

  • StatsUsageError: If agent does not match agent_stats.agent.

Source code in mabby/stats.py
def __setitem__(self, agent: Agent, agent_stats: AgentStats) -> None:
    """Sets the statistics for an agent.

    Args:
        agent: The agent to set the statistics of.
        agent_stats: The agent statistics to set.
    """
    if agent != agent_stats.agent:
        raise StatsUsageError("agents specified in key and value don't match")
    self._stats_dict[agent] = agent_stats

add(agent_stats)

Adds statistics for an agent.

Parameters:

  • agent_stats (AgentStats, required): The agent statistics to add.

Source code in mabby/stats.py
def add(self, agent_stats: AgentStats) -> None:
    """Adds statistics for an agent.

    Args:
        agent_stats: The agent statistics to add.
    """
    self._stats_dict[agent_stats.agent] = agent_stats
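
SimulationStats behaves like a dictionary keyed by agent, with two ways in: add() takes the key from `agent_stats.agent`, while item assignment checks that the key and value agree. The sketch below mirrors that behavior with strings standing in for Agent objects; `AgentStatsStub`, `StatsSketch`, and the local `StatsUsageError` are hypothetical stand-ins.

```python
from dataclasses import dataclass, field


class StatsUsageError(Exception):
    """Stand-in for mabby's StatsUsageError."""


@dataclass
class AgentStatsStub:
    """Hypothetical per-agent results: the owning agent plus a metric series."""
    agent: str
    regret: list[float] = field(default_factory=list)


class StatsSketch:
    """Mirrors the SimulationStats mapping behavior documented above."""

    def __init__(self) -> None:
        self._stats_dict: dict[str, AgentStatsStub] = {}

    def add(self, agent_stats: AgentStatsStub) -> None:
        # The key is taken from the stats object itself, so it cannot mismatch.
        self._stats_dict[agent_stats.agent] = agent_stats

    def __setitem__(self, agent: str, agent_stats: AgentStatsStub) -> None:
        if agent != agent_stats.agent:
            raise StatsUsageError("agents specified in key and value don't match")
        self._stats_dict[agent] = agent_stats

    def __getitem__(self, agent: str) -> AgentStatsStub:
        return self._stats_dict[agent]

    def __contains__(self, agent: str) -> bool:
        return agent in self._stats_dict


stats = StatsSketch()
stats.add(AgentStatsStub("greedy", [0.3, 0.1]))
print("greedy" in stats, stats["greedy"].regret)  # True [0.3, 0.1]
try:
    stats["ucb"] = AgentStatsStub("greedy")
except StatsUsageError as err:
    print(err)  # agents specified in key and value don't match
```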

plot(metric)

Generates a plot for a simulation metric.

Parameters:

  • metric (Metric, required): The metric to plot.

Source code in mabby/stats.py
def plot(self, metric: Metric) -> None:
    """Generates a plot for a simulation metric.

    Args:
        metric: The metric to plot.
    """
    for agent, agent_stats in self._stats_dict.items():
        plt.plot(agent_stats[metric], label=str(agent))
    plt.legend()
    plt.show()

plot_optimality()

Generates a plot for the optimality metric.

Source code in mabby/stats.py
def plot_optimality(self) -> None:
    """Generates a plot for the optimality metric."""
    self.plot(metric=Metric.OPTIMALITY)

plot_regret(cumulative=True)

Generates a plot for the regret or cumulative regret metrics.

Parameters:

  • cumulative (bool, default True): Whether to use the cumulative regret.

Source code in mabby/stats.py
def plot_regret(self, cumulative: bool = True) -> None:
    """Generates a plot for the regret or cumulative regret metrics.

    Args:
        cumulative: Whether to use the cumulative regret.
    """
    self.plot(metric=Metric.CUM_REGRET if cumulative else Metric.REGRET)

plot_rewards(cumulative=True)

Generates a plot for the rewards or cumulative rewards metrics.

Parameters:

  • cumulative (bool, default True): Whether to use the cumulative rewards.

Source code in mabby/stats.py
def plot_rewards(self, cumulative: bool = True) -> None:
    """Generates a plot for the rewards or cumulative rewards metrics.

    Args:
        cumulative: Whether to use the cumulative rewards.
    """
    self.plot(metric=Metric.CUM_REWARDS if cumulative else Metric.REWARDS)

Strategy()

Bases: ABC, EnforceOverrides

Base class for a bandit strategy.

A strategy provides the computational logic for choosing which bandit arms to play and updating parameter estimates.

Source code in mabby/strategies/strategy.py
@abstractmethod
def __init__(self) -> None:
    """Initializes a bandit strategy."""

Ns: NDArray[np.uint32] property abstractmethod

The number of times each arm has been played.

Qs: NDArray[np.float64] property abstractmethod

The current estimated action values for each arm.

__repr__() abstractmethod

Returns a string representation of the strategy.

Source code in mabby/strategies/strategy.py
@abstractmethod
def __repr__(self) -> str:
    """Returns a string representation of the strategy."""

agent(**kwargs)

Creates an agent following the strategy.

Parameters:

  • **kwargs (str, optional): Parameters for initializing the agent (see Agent).

Returns:

  • Agent: The created agent with the strategy.

Source code in mabby/strategies/strategy.py
def agent(self, **kwargs: str) -> Agent:
    """Creates an agent following the strategy.

    Args:
        **kwargs: Parameters for initializing the agent (see
            [`Agent`][mabby.agent.Agent])

    Returns:
        The created agent with the strategy.
    """
    return Agent(strategy=self, **kwargs)

choose(rng) abstractmethod

Returns the next arm to play.

Parameters:

  • rng (Generator, required): A random number generator.

Returns:

  • int: The index of the arm to play.

Source code in mabby/strategies/strategy.py
@abstractmethod
def choose(self, rng: Generator) -> int:
    """Returns the next arm to play.

    Args:
        rng: A random number generator.

    Returns:
        The index of the arm to play.
    """

prime(k, steps) abstractmethod

Primes the strategy before running a trial.

Parameters:

  • k (int, required): The number of bandit arms to choose from.
  • steps (int, required): The number of steps the simulation will be run for.

Source code in mabby/strategies/strategy.py
@abstractmethod
def prime(self, k: int, steps: int) -> None:
    """Primes the strategy before running a trial.

    Args:
        k: The number of bandit arms to choose from.
        steps: The number of steps the simulation will be run for.
    """

update(choice, reward, rng=None) abstractmethod

Updates internal parameter estimates based on reward observation.

Parameters:

  • choice (int, required): The most recent choice made.
  • reward (float, required): The observed reward from the agent's most recent choice.
  • rng (Generator | None, default None): A random number generator.

Source code in mabby/strategies/strategy.py
@abstractmethod
def update(self, choice: int, reward: float, rng: Generator | None = None) -> None:
    """Updates internal parameter estimates based on reward observation.

    Args:
        choice: The most recent choice made.
        reward: The observed reward from the agent's most recent choice.
        rng: A random number generator.
    """