# stats

Provides metric tracking for multi-armed bandit simulations.

AgentStats(agent, bandit, steps, metrics=None)

Statistics for an agent in a multi-armed bandit simulation.

All available metrics are tracked by default. Alternatively, a specific list can be specified through the metrics argument.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `agent` | `Agent` | The agent that statistics are tracked for | *required* |
| `bandit` | `Bandit` | The bandit of the simulation being run | *required* |
| `steps` | `int` | The number of steps per trial in the simulation | *required* |
| `metrics` | `Iterable[Metric] \| None` | A collection of metrics to track. | `None` |

Source code in `mabby/stats.py` (lines 206–230):
def __init__(
    self,
    agent: Agent,
    bandit: Bandit,
    steps: int,
    metrics: Iterable[Metric] | None = None,
):
    """Sets up zeroed per-step accumulators for an agent.

    When ``metrics`` is omitted every declared metric is tracked; otherwise only
    the given metrics are. Storage is allocated per *base* metric only, because
    derived metrics are computed from their base when they are read.

    Args:
        agent: The agent that statistics are tracked for
        bandit: The bandit of the simulation being run
        steps: The number of steps per trial in the simulation
        metrics: A collection of metrics to track.
    """
    self.agent: Agent = agent  #: The agent that statistics are tracked for
    self._bandit = bandit
    self._steps = steps
    # Number of update() calls recorded for each step; __getitem__ divides by it.
    self._counts = np.zeros(steps)

    if metrics is None:
        metrics = list(Metric)  # no selection given: track every metric
    self._stats = {
        base_metric: np.zeros(steps)
        for base_metric in Metric.map_to_base(metrics)
    }

__getitem__(metric)

Gets values for a metric.

If the metric is not a base metric, the values are automatically transformed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `Metric` | The metric to get the values for. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `NDArray[np.float64]` | An array of values for the metric. |

Source code in `mabby/stats.py` (lines 236–249):
def __getitem__(self, metric: Metric) -> NDArray[np.float64]:
    """Returns the per-step values of ``metric``.

    The accumulated totals of the metric's base metric are divided element-wise
    by the number of updates recorded at each step. The quotient is then passed
    through the metric's transform, which leaves base metrics unchanged.

    Args:
        metric: The metric to get the values for.

    Returns:
        An array of values for the metric.

    Raises:
        KeyError: If the metric's base metric is not being tracked.
    """
    totals = self._stats[metric.base]
    # Steps that were never updated have a zero count; numpy then yields NaN
    # (0/0) or inf instead of raising, and its warnings are silenced here.
    with np.errstate(divide="ignore", invalid="ignore"):
        averaged = totals / self._counts
    return metric.transform(averaged)

__len__()

Returns the number of steps each trial is tracked for.

Source code in `mabby/stats.py` (lines 232–234):
def __len__(self) -> int:
    """Gives the number of steps tracked per trial.

    This is the ``steps`` value the statistics were created with.
    """
    steps: int = self._steps
    return steps

update(step, choice, reward)

Updates metric values for the latest simulation step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step` | `int` | The number of the step. | *required* |
| `choice` | `int` | The choice made by the agent. | *required* |
| `reward` | `float` | The reward observed by the agent. | *required* |

Source code in `mabby/stats.py` (lines 251–266):
def update(self, step: int, choice: int, reward: float) -> None:
    """Updates metric values for the latest simulation step.

    Only the base metrics selected at construction are accumulated. The
    bandit is asked for regret or optimality only when that metric is
    tracked, so untracked metrics add no per-step work.

    Args:
        step: The number of the step (index into the per-step arrays).
        choice: The choice made by the agent.
        reward: The reward observed by the agent.
    """
    stats = self._stats
    if Metric.REGRET in stats:
        # Queried lazily: previously regret was computed even when untracked.
        stats[Metric.REGRET][step] += self._bandit.regret(choice)
    if Metric.OPTIMALITY in stats:
        stats[Metric.OPTIMALITY][step] += int(self._bandit.is_opt(choice))
    if Metric.REWARDS in stats:
        stats[Metric.REWARDS][step] += reward
    # One more observation for this step; __getitem__ divides totals by it.
    self._counts[step] += 1

Metric(label, base=None, transform=None)

Bases: Enum

Enum for metrics that simulations can track.

Metrics can be derived from other metrics through specifying a base metric and a transform function. This is useful for things like defining cumulative versions of an existing metric, where the transformed values can be computed "lazily" instead of being redundantly stored.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `label` | `str` | Verbose name of the metric (title case) | *required* |
| `base` | `str \| None` | Name of the base metric | `None` |
| `transform` | `Callable[[NDArray[np.float64]], NDArray[np.float64]] \| None` | Transformation function from the base metric | `None` |

Source code in `mabby/stats.py` (lines 45–69):
def __init__(
    self,
    label: str,
    base: str | None = None,
    transform: Callable[[NDArray[np.float64]], NDArray[np.float64]] | None = None,
):
    """Registers a metric and records how (if at all) it derives from another.

    A metric is *derived* when both a ``base`` metric name and a ``transform``
    function are given. Its values are then computed from the base metric's
    values on demand instead of being stored separately, which suits things
    like cumulative versions of an existing metric. If either argument is
    missing (or falsy), the metric is a base metric.

    Args:
        label: Verbose name of the metric (title case)
        base: Name of the base metric
        transform: Transformation function from the base metric
    """
    registry = self.__class__.__MAPPING__
    # Register under this member's name first; ``base`` is resolved by name
    # through the same registry, so a base must be registered before use.
    registry[self._name_] = self
    self._label = label
    mapping: MetricMapping | None = None
    if base and transform:
        mapping = MetricMapping(base=registry[base], transform=transform)
    self._mapping: MetricMapping | None = mapping

base: Metric property

The base metric that the metric is transformed from.

If the metric is already a base metric, the metric itself is returned.

__repr__()

Returns the verbose name of the metric.

Source code in `mabby/stats.py` (lines 71–73):
def __repr__(self) -> str:
    """Gives the human-readable label the metric was declared with."""
    label: str = self._label
    return label

is_base()

Returns whether the metric is a base metric.

Returns:

| Type | Description |
| --- | --- |
| `bool` | `True` if the metric is a base metric, `False` otherwise. |

Source code in `mabby/stats.py` (lines 75–81):
def is_base(self) -> bool:
    """Reports whether this metric stores its own values.

    Returns:
        ``True`` when no base mapping was configured (a base metric),
        ``False`` when the metric's values are derived from another metric.
    """
    has_mapping = self._mapping is not None
    return not has_mapping

map_to_base(metrics) classmethod

Traces all metrics back to their base metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metrics` | `Iterable[Metric]` | A collection of metrics. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Iterable[Metric]` | A set containing the base metrics of all the input metrics. |

Source code in `mabby/stats.py` (lines 93–103):
@classmethod
def map_to_base(cls, metrics: Iterable[Metric]) -> set[Metric]:
    """Traces all metrics back to their base metrics.

    Duplicates collapse. Several derived metrics that share one base, plus the
    base itself, yield that base only once.

    Args:
        metrics: A collection of metrics. Any iterable is accepted; it is
            consumed exactly once.

    Returns:
        A set containing the base metrics of all the input metrics.
    """
    return {metric.base for metric in metrics}

transform(values)

Transforms values from the base metric.

If the metric is already a base metric, the input values are returned.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `values` | `NDArray[np.float64]` | An array of input values for the base metric. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `NDArray[np.float64]` | An array of transformed values for the metric. |

Source code in `mabby/stats.py` (lines 105–118):
def transform(self, values: NDArray[np.float64]) -> NDArray[np.float64]:
    """Derives this metric's values from its base metric's values.

    For a base metric there is nothing to derive, and ``values`` is returned
    as-is (the same object, not a copy).

    Args:
        values: An array of input values for the base metric.

    Returns:
        An array of transformed values for the metric.
    """
    mapping = self._mapping
    if mapping is None:
        return values
    return mapping.transform(values)

MetricMapping dataclass

Transformation from a base metric.

See Metric for examples of metric mappings.

SimulationStats(simulation)

Statistics for a multi-armed bandit simulation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `simulation` | `Simulation` | The simulation to track. | *required* |

Source code in `mabby/stats.py` (lines 124–131):
def __init__(self, simulation: Simulation):
    """Creates an empty statistics collection bound to ``simulation``.

    Per-agent statistics are registered afterwards via ``add`` or item
    assignment.

    Args:
        simulation: The simulation to track.
    """
    self._simulation: Simulation = simulation
    # Maps each agent to its AgentStats; starts out empty.
    agent_table: dict[Agent, AgentStats] = {}
    self._stats_dict: dict[Agent, AgentStats] = agent_table

__contains__(agent)

Returns whether an agent's statistics are present.

Returns:

| Type | Description |
| --- | --- |
| `bool` | `True` if an agent's statistics are present, `False` otherwise. |

Source code in `mabby/stats.py` (lines 163–169):
def __contains__(self, agent: Agent) -> bool:
    """Checks whether statistics have been recorded for ``agent``.

    Returns:
        ``True`` if an agent's statistics are present, ``False`` otherwise.
    """
    tracked = self._stats_dict
    return agent in tracked

__getitem__(agent)

Gets statistics for an agent.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `agent` | `Agent` | The agent to get the statistics of. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `AgentStats` | The statistics of the agent. |

Source code in `mabby/stats.py` (lines 141–150):
def __getitem__(self, agent: Agent) -> AgentStats:
    """Looks up the statistics recorded for ``agent``.

    Args:
        agent: The agent to get the statistics of.

    Returns:
        The statistics of the agent.

    Raises:
        KeyError: If no statistics have been stored for ``agent``.
    """
    stats_by_agent = self._stats_dict
    return stats_by_agent[agent]

__setitem__(agent, agent_stats)

Sets the statistics for an agent.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `agent` | `Agent` | The agent to set the statistics of. | *required* |
| `agent_stats` | `AgentStats` | The agent statistics to set. | *required* |

Source code in `mabby/stats.py` (lines 152–161):
def __setitem__(self, agent: Agent, agent_stats: AgentStats) -> None:
    """Stores ``agent_stats`` under ``agent``, replacing any previous entry.

    Args:
        agent: The agent to set the statistics of.
        agent_stats: The agent statistics to set.

    Raises:
        StatsUsageError: If ``agent`` differs from ``agent_stats.agent``.
    """
    mismatched = agent != agent_stats.agent
    if mismatched:
        raise StatsUsageError("agents specified in key and value don't match")
    stats_by_agent = self._stats_dict
    stats_by_agent[agent] = agent_stats

add(agent_stats)

Adds statistics for an agent.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `agent_stats` | `AgentStats` | The agent statistics to add. | *required* |

Source code in `mabby/stats.py` (lines 133–139):
def add(self, agent_stats: AgentStats) -> None:
    """Stores statistics under the agent they belong to.

    The key is taken from ``agent_stats.agent``, and any earlier entry for
    that agent is replaced.

    Args:
        agent_stats: The agent statistics to add.
    """
    owner = agent_stats.agent
    self._stats_dict[owner] = agent_stats

plot(metric)

Generates a plot for a simulation metric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `Metric` | The metric to plot. | *required* |

Source code in `mabby/stats.py` (lines 171–180):
def plot(self, metric: Metric) -> None:
    """Draws one line per tracked agent for ``metric`` and shows the figure.

    Each agent's values (its statistics indexed by ``metric``) are plotted and
    labelled with ``str(agent)``. A legend is then added and the figure is
    displayed.

    Args:
        metric: The metric to plot.
    """
    for tracked_agent, per_agent in self._stats_dict.items():
        series = per_agent[metric]
        plt.plot(series, label=str(tracked_agent))
    plt.legend()
    plt.show()

plot_optimality()

Generates a plot for the optimality metric.

Source code in `mabby/stats.py` (lines 190–192):
def plot_optimality(self) -> None:
    """Plots the optimality metric for every tracked agent."""
    chosen = Metric.OPTIMALITY
    self.plot(metric=chosen)

plot_regret(cumulative=True)

Generates a plot for the regret or cumulative regret metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cumulative` | `bool` | Whether to use the cumulative regret. | `True` |

Source code in `mabby/stats.py` (lines 182–188):
def plot_regret(self, cumulative: bool = True) -> None:
    """Plots per-step regret, or its running total, for every tracked agent.

    Args:
        cumulative: Whether to use the cumulative regret.
    """
    if cumulative:
        chosen = Metric.CUM_REGRET
    else:
        chosen = Metric.REGRET
    self.plot(metric=chosen)

plot_rewards(cumulative=True)

Generates a plot for the rewards or cumulative rewards metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cumulative` | `bool` | Whether to use the cumulative rewards. | `True` |

Source code in `mabby/stats.py` (lines 194–200):
def plot_rewards(self, cumulative: bool = True) -> None:
    """Plots per-step rewards, or their running total, for every tracked agent.

    Args:
        cumulative: Whether to use the cumulative rewards.
    """
    if cumulative:
        chosen = Metric.CUM_REWARDS
    else:
        chosen = Metric.REWARDS
    self.plot(metric=chosen)