stats
Provides metric tracking for multi-armed bandit simulations.
AgentStats(agent, bandit, steps, metrics=None)
Statistics for an agent in a multi-armed bandit simulation.
All available metrics are tracked by default. Alternatively, a specific collection of metrics can be passed through the metrics argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent | Agent | The agent that statistics are tracked for. | required |
bandit | Bandit | The bandit of the simulation being run. | required |
steps | int | The number of steps per trial in the simulation. | required |
metrics | Iterable[Metric] \| None | A collection of metrics to track. If None, all available metrics are tracked. | None |
Source code in mabby/stats.py
__getitem__(metric)
Gets values for a metric.
If the metric is not a base metric, the values are automatically transformed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The metric to get the values for. | required |
Returns:
Type | Description |
---|---|
NDArray[np.float64] | An array of values for the metric. |
__len__()
Returns the number of steps each trial is tracked for.
update(step, choice, reward)
Updates metric values for the latest simulation step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | int | The number of the step. | required |
choice | int | The choice made by the agent. | required |
reward | float | The reward observed by the agent. | required |
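A minimal sketch of the storage scheme described for AgentStats: only base metrics are stored, one value per step, and derived metrics are transformed when they are read. This is a standalone illustration, not mabby's implementation; the class name AgentStatsSketch, the string metric names, the cumulative_rewards transform, and the best_arm argument (standing in for information the real class would get from its Bandit) are all assumptions.

```python
from itertools import accumulate


class AgentStatsSketch:
    """Hypothetical, simplified stand-in for AgentStats (not mabby's code).

    Only base metrics are stored per step; derived metrics are computed on
    access by applying a transform to the stored base values.
    """

    # Derived metric -> (base metric, transform). Names are illustrative.
    DERIVED = {
        "cumulative_rewards": ("rewards", lambda v: list(accumulate(v))),
    }
    ALL_METRICS = ("rewards", "optimality", "cumulative_rewards")

    def __init__(self, best_arm, steps, metrics=None):
        # best_arm stands in for information the real class takes from a Bandit.
        requested = self.ALL_METRICS if metrics is None else tuple(metrics)
        # Trace each requested metric back to the base metric that must be stored.
        bases = {self.DERIVED.get(m, (m, None))[0] for m in requested}
        self._best_arm = best_arm
        self._steps = steps
        self._values = {base: [0.0] * steps for base in bases}

    def __len__(self):
        # Number of steps each trial is tracked for.
        return self._steps

    def update(self, step, choice, reward):
        # Write the latest step's value for each stored base metric.
        if "rewards" in self._values:
            self._values["rewards"][step] = reward
        if "optimality" in self._values:
            self._values["optimality"][step] = float(choice == self._best_arm)

    def __getitem__(self, metric):
        # Base metrics are returned as a copy; derived ones are transformed lazily.
        base, transform = self.DERIVED.get(metric, (metric, None))
        values = self._values[base]
        return list(values) if transform is None else transform(values)


stats = AgentStatsSketch(best_arm=1, steps=3)
for step, (choice, reward) in enumerate([(1, 1.0), (0, 0.5), (1, 2.0)]):
    stats.update(step, choice, reward)
print(stats["cumulative_rewards"])  # [1.0, 1.5, 3.5]
```

Requesting only cumulative_rewards would store just the rewards values, which is the point of deriving metrics instead of storing them.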
Metric(label, base=None, transform=None)
Bases: Enum
Enum for metrics that simulations can track.
Metrics can be derived from other metrics by specifying a base metric and a transform function. This is useful for defining, for example, cumulative versions of an existing metric, whose values can then be computed lazily from the base metric instead of being stored redundantly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
label | str | Verbose name of the metric (title case). | required |
base | str \| None | Name of the base metric. | None |
transform | Callable[[NDArray[np.float64]], NDArray[np.float64]] \| None | Transformation function from the base metric. | None |
base: Metric (property)
The base metric that the metric is transformed from.
If the metric is already a base metric, the metric itself is returned.
__repr__()
Returns the verbose name of the metric.
is_base()
Returns whether the metric is a base metric.
Returns:
Type | Description |
---|---|
bool | Whether the metric is a base metric. |
map_to_base(metrics) (classmethod)
Traces all metrics back to their base metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metrics | Iterable[Metric] | A collection of metrics. | required |
Returns:
Type | Description |
---|---|
Iterable[Metric] | A set containing the base metrics of all the input metrics. |
transform(values)
Transforms values from the base metric.
If the metric is already a base metric, the input values are returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | NDArray[np.float64] | An array of input values for the base metric. | required |
Returns:
Type | Description |
---|---|
NDArray[np.float64] | An array of transformed values for the metric. |
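The derived-metric design described for Metric (a label, an optional base metric name, and an optional transform) can be sketched as a standalone enum. This is an illustration under stated assumptions, not mabby's source: the name MetricSketch, its members and labels, and the running-sum transform are hypothetical, and plain lists stand in for NDArray[np.float64].

```python
from enum import Enum
from itertools import accumulate


def running_sum(values):
    """Cumulative version of a metric: the running sum of its values."""
    return list(accumulate(values))


class MetricSketch(Enum):
    # Each value is (label, base member name, transform); base metrics omit
    # the last two. Member names and labels are illustrative assumptions.
    REWARDS = ("Rewards",)
    CUM_REWARDS = ("Cumulative Rewards", "REWARDS", running_sum)

    def __init__(self, label, base=None, transform=None):
        # Enum unpacks each member's tuple value into these arguments.
        self.label = label
        self._base_name = base
        self._transform = transform

    def __repr__(self):
        return self.label

    @property
    def base(self):
        # A base metric is its own base; otherwise look the base up by name.
        return self if self._base_name is None else type(self)[self._base_name]

    def is_base(self):
        return self._base_name is None

    def transform(self, values):
        # Base metrics pass values through unchanged.
        return values if self.is_base() else self._transform(values)

    @classmethod
    def map_to_base(cls, metrics):
        # Set of base metrics needed to compute every metric in `metrics`.
        return {metric.base for metric in metrics}


print(MetricSketch.map_to_base([MetricSketch.REWARDS, MetricSketch.CUM_REWARDS]))  # {Rewards}
print(MetricSketch.CUM_REWARDS.transform([1.0, 2.0, 3.0]))  # [1.0, 3.0, 6.0]
```

Storing the base by name rather than by member lets a derived member refer to a base defined in the same enum body, which matches the str | None type documented for the base parameter.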
MetricMapping (dataclass)
SimulationStats(simulation)
Statistics for a multi-armed bandit simulation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
simulation | Simulation | The simulation to track. | required |
__contains__(agent)
Returns whether an agent's statistics are present.
Returns:
Type | Description |
---|---|
bool | Whether statistics for the agent are present. |
__getitem__(agent)
Gets statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent | Agent | The agent to get the statistics of. | required |
Returns:
Type | Description |
---|---|
AgentStats | The statistics of the agent. |
__setitem__(agent, agent_stats)
Sets the statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent | Agent | The agent to set the statistics of. | required |
agent_stats | AgentStats | The agent statistics to set. | required |
add(agent_stats)
Adds statistics for an agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
agent_stats | AgentStats | The agent statistics to add. | required |
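Taken together, add, __getitem__, __setitem__, and __contains__ make SimulationStats behave like a mapping from agents to their AgentStats. A minimal sketch of that container protocol follows; it is not mabby's code, strings stand in for Agent objects, and it assumes (hypothetically) that each statistics object exposes the agent it belongs to as an agent attribute, which add would need in order to choose the key.

```python
from dataclasses import dataclass


@dataclass
class AgentStatsStub:
    """Hypothetical stand-in for AgentStats; only records its agent."""

    agent: str


class SimulationStatsSketch:
    """Hypothetical mapping from agents to their statistics (not mabby's code)."""

    def __init__(self):
        self._stats = {}

    def add(self, agent_stats):
        # The key is read from the statistics object itself (assumed attribute).
        self._stats[agent_stats.agent] = agent_stats

    def __getitem__(self, agent):
        return self._stats[agent]

    def __setitem__(self, agent, agent_stats):
        self._stats[agent] = agent_stats

    def __contains__(self, agent):
        return agent in self._stats


sim_stats = SimulationStatsSketch()
sim_stats.add(AgentStatsStub(agent="epsilon-greedy"))
print("epsilon-greedy" in sim_stats)  # True
print("ucb1" in sim_stats)  # False
```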
plot(metric)
Generates a plot for a simulation metric.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The metric to plot. | required |
plot_optimality()
Generates a plot for the optimality metric.
plot_regret(cumulative=True)
Generates a plot for the regret or cumulative regret metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cumulative | bool | Whether to use the cumulative regret. | True |
plot_rewards(cumulative=True)
Generates a plot for the rewards or cumulative rewards metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cumulative | bool | Whether to use the cumulative rewards. | True |
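The plot_optimality, plot_regret, and plot_rewards helpers read as thin wrappers that pick a metric (using the cumulative flag where it applies) and pass it to plot. The sketch below assumes that delegation; it is not mabby's code, the metric names are hypothetical strings rather than Metric members, and plot only records each request instead of drawing a figure.

```python
class PlotDispatchSketch:
    """Hypothetical sketch of the plot_* helpers delegating to plot()."""

    def __init__(self):
        self.requested = []  # metrics passed to plot(), in call order

    def plot(self, metric):
        # Stand-in: a real implementation would draw a figure for `metric`.
        self.requested.append(metric)

    def plot_optimality(self):
        self.plot("optimality")

    def plot_regret(self, cumulative=True):
        self.plot("cumulative_regret" if cumulative else "regret")

    def plot_rewards(self, cumulative=True):
        self.plot("cumulative_rewards" if cumulative else "rewards")


sketch = PlotDispatchSketch()
sketch.plot_regret()
sketch.plot_rewards(cumulative=False)
print(sketch.requested)  # ['cumulative_regret', 'rewards']
```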