
arms

Provides the Arm base class and arm implementations for some common reward distributions.

Arm(**kwargs)

Bases: ABC, EnforceOverrides

Base class for a bandit arm implementing a reward distribution.

An arm represents one of the decision choices available to the agent in a bandit problem. It has a hidden reward distribution and can be played by the agent to generate observable rewards.
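The idea of a hidden distribution that is only observed through plays can be illustrated without mabby itself. The sketch below assumes NumPy and a hypothetical hidden success probability of 0.3. It plays a Bernoulli-style arm many times and shows that the average of the observed rewards approaches the hidden mean, which is the quantity a bandit agent is trying to learn.

```python
import numpy as np

HIDDEN_P = 0.3  # hypothetical hidden parameter; the agent never reads it directly


def play(rng: np.random.Generator) -> float:
    """Sample one reward: 1.0 with probability HIDDEN_P, else 0.0."""
    return float(rng.random() < HIDDEN_P)


rng = np.random.default_rng(seed=0)
rewards = [play(rng) for _ in range(50_000)]

# The agent only sees `rewards`; their average estimates the hidden mean.
estimate = sum(rewards) / len(rewards)
print(f"estimated mean: {estimate:.3f} (hidden mean: {HIDDEN_P})")
```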

Source code in mabby/arms.py
@abstractmethod
def __init__(self, **kwargs: float):
    """Initializes an arm."""

mean: float property abstractmethod

The mean reward of the arm.

Returns:

float: The computed mean of the arm's reward distribution.
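This page labels mean as both a property and an abstract method. Assuming it is declared by stacking @property on top of @abstractmethod (the order the Python abc documentation prescribes), a subclass that does not override mean cannot be instantiated. The classes below are hypothetical stand-ins used only to demonstrate that behavior; they are not mabby's own classes.

```python
from abc import ABC, abstractmethod


class BaseArm(ABC):
    """Hypothetical stand-in for Arm, reduced to the abstract mean property."""

    @property
    @abstractmethod
    def mean(self) -> float:
        """The mean reward of the arm."""


class IncompleteArm(BaseArm):
    """Does not override mean, so it remains abstract."""


class FixedArm(BaseArm):
    """Hypothetical arm that always pays `value`, so its mean is `value`."""

    def __init__(self, value: float):
        self.value = value

    @property
    def mean(self) -> float:
        return self.value


try:
    IncompleteArm()
    incomplete_rejected = False
except TypeError:
    # ABCMeta refuses to instantiate classes with unimplemented abstract members.
    incomplete_rejected = True

print(incomplete_rejected)  # True
print(FixedArm(2.5).mean)  # 2.5
```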

__repr__() abstractmethod

Returns the string representation of the arm.

Source code in mabby/arms.py
@abstractmethod
def __repr__(self) -> str:
    """Returns the string representation of the arm."""

bandit(rng=None, seed=None, **kwargs) classmethod

Creates a bandit with arms of the same reward distribution type.

Parameters:

rng (Generator | None, default None): A random number generator.
seed (int | None, default None): A seed for random number generation, used if rng is not provided.
**kwargs (list[float], default {}): Keyword arguments mapping each arm parameter name to a list of values, one value per arm. All lists should have the same length; because the lists are combined with zip, extra values in longer lists are silently ignored.

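Returns:

Bandit: A bandit with one arm per set of parameter values.

Raises:

ValueError: If no arm can be created, for example when no keyword arguments are given or a parameter list is empty.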

Source code in mabby/arms.py
@classmethod
def bandit(
    cls,
    rng: Generator | None = None,
    seed: int | None = None,
    **kwargs: list[float],
) -> Bandit:
    """Creates a bandit with arms of the same reward distribution type.

    Args:
        rng: A random number generator.
        seed: A seed for random number generation if ``rng`` is not provided.
        **kwargs: A dictionary where keys are arm parameter names and values are
            lists of parameter values for each arm.

    Returns:
        A bandit with the specified arms.
    """
    params_dicts = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
    if len(params_dicts) == 0:
        raise ValueError("insufficient parameters to create an arm")
    return Bandit([cls(**params) for params in params_dicts], rng, seed)
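The first line of bandit turns column-oriented keyword arguments into one parameter dict per arm. The standalone function below reproduces that expression outside the class so the transformation can be inspected directly; the name split_arm_params is hypothetical. It also shows that zip stops at the shortest list, so mismatched lengths drop trailing values instead of raising.

```python
def split_arm_params(**kwargs: list[float]) -> list[dict[str, float]]:
    """Hypothetical helper mirroring the first lines of Arm.bandit.

    Transposes {"name": [v0, v1, ...]} into [{"name": v0}, {"name": v1}, ...].
    """
    params_dicts = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
    if len(params_dicts) == 0:
        raise ValueError("insufficient parameters to create an arm")
    return params_dicts


# Two Gaussian-style arms: (loc=0.0, scale=1.0) and (loc=0.5, scale=2.0).
two_arms = split_arm_params(loc=[0.0, 0.5], scale=[1.0, 2.0])
print(two_arms)  # [{'loc': 0.0, 'scale': 1.0}, {'loc': 0.5, 'scale': 2.0}]

# Mismatched lengths: zip truncates to the shortest list, so only one arm results.
truncated = split_arm_params(loc=[0.0, 0.5, 1.0], scale=[1.0])
print(truncated)  # [{'loc': 0.0, 'scale': 1.0}]
```

Given the signature above, the corresponding mabby call would be GaussianArm.bandit(loc=[0.0, 0.5], scale=[1.0, 2.0], seed=42), which passes each of these dicts to GaussianArm as keyword arguments.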

play(rng) abstractmethod

Plays the arm and samples a reward.

Parameters:

rng (Generator, required): A random number generator.

Returns:

float: The sampled reward from the arm's reward distribution.

Source code in mabby/arms.py
@abstractmethod
def play(self, rng: Generator) -> float:
    """Plays the arm and samples a reward.

    Args:
        rng: A random number generator.

    Returns:
        The sampled reward from the arm's reward distribution.
    """
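A concrete arm has to supply __init__, mean, __repr__ and play. The sketch below is a self-contained, hypothetical UniformArm written against a local stand-in for the abstract base (it does not import mabby, and this page does not document a uniform arm in mabby). It also illustrates why play receives the generator as an argument: two generators created with the same seed produce identical reward sequences.

```python
from abc import ABC, abstractmethod

import numpy as np
from numpy.random import Generator


class ArmSketch(ABC):
    """Local stand-in mirroring the abstract members documented for Arm."""

    @abstractmethod
    def play(self, rng: Generator) -> float: ...

    @property
    @abstractmethod
    def mean(self) -> float: ...

    @abstractmethod
    def __repr__(self) -> str: ...


class UniformArm(ArmSketch):
    """Hypothetical arm whose rewards are uniform on [low, high)."""

    def __init__(self, low: float, high: float):
        if low > high:
            raise ValueError(f"low {low} must not exceed high {high}")
        self.low = low
        self.high = high

    def play(self, rng: Generator) -> float:
        # All randomness comes from the caller's generator, never a global RNG.
        return float(rng.uniform(self.low, self.high))

    @property
    def mean(self) -> float:
        return (self.low + self.high) / 2

    def __repr__(self) -> str:
        return f"UniformArm(low={self.low}, high={self.high})"


arm = UniformArm(0.0, 2.0)
rng_a = np.random.default_rng(7)
rng_b = np.random.default_rng(7)
run_a = [arm.play(rng_a) for _ in range(5)]
run_b = [arm.play(rng_b) for _ in range(5)]

print(arm, arm.mean)  # UniformArm(low=0.0, high=2.0) 1.0
print(run_a == run_b)  # True: same seed, same rewards
```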

BernoulliArm(p)

Bases: Arm

Bandit arm with a Bernoulli reward distribution.

Parameters:

p (float, required): Parameter of the Bernoulli distribution, i.e. the probability that a play yields a reward of 1 rather than 0. Must lie in [0, 1]; otherwise ValueError is raised.
Source code in mabby/arms.py
def __init__(self, p: float):
    """Initializes a Bernoulli arm.

    Args:
        p: Parameter of the Bernoulli distribution.
    """
    if p < 0 or p > 1:
        raise ValueError(
            f"float {str(p)} is not a valid probability for Bernoulli distribution"
        )

    self.p: float = p  #: Parameter of the Bernoulli distribution
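The excerpt above only validates and stores p; the sampling code is not reproduced on this page. A Bernoulli reward is 1 with probability p and 0 otherwise, so the arm's mean is p. The function below is a hypothetical sketch of how such an arm could be played with a NumPy Generator, assuming rewards are returned as floats; it is not mabby's implementation.

```python
import numpy as np


def play_bernoulli(p: float, rng: np.random.Generator) -> float:
    """Hypothetical sampler: 1.0 with probability p, otherwise 0.0."""
    if p < 0 or p > 1:
        raise ValueError(f"float {p} is not a valid probability for Bernoulli distribution")
    # A binomial draw with n=1 is a single Bernoulli trial.
    return float(rng.binomial(1, p))


rng = np.random.default_rng(seed=1)
rewards = [play_bernoulli(0.7, rng) for _ in range(20_000)]
print(sum(rewards) / len(rewards))  # close to 0.7, the arm's mean
```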

GaussianArm(loc, scale)

Bases: Arm

Bandit arm with a Gaussian reward distribution.

Parameters:

loc (float, required): Mean ("center") of the Gaussian distribution.
scale (float, required): Standard deviation of the Gaussian distribution. Must be non-negative; otherwise ValueError is raised.
Source code in mabby/arms.py
def __init__(self, loc: float, scale: float):
    """Initializes a Gaussian arm.

    Args:
        loc: Mean ("center") of the Gaussian distribution.
        scale: Standard deviation of the Gaussian distribution.
    """
    if scale < 0:
        raise ValueError(
            f"float {str(scale)} is not a valid scale for Gaussian distribution"
        )

    self.loc: float = loc  #: Mean ("center") of the Gaussian distribution
    self.scale: float = scale  #: Standard deviation of the Gaussian distribution
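Only the constructor is shown here. For a Gaussian arm the mean is simply loc, and a reward can be drawn with NumPy's Generator.normal(loc, scale). The sketch below assumes that sampling strategy (hypothetical, not taken from mabby's source) and checks the sample mean and standard deviation against loc and scale. Because the constructor only rejects negative values, scale == 0 is accepted; under this sampling assumption it yields a degenerate arm that always pays loc.

```python
import numpy as np


def play_gaussian(loc: float, scale: float, rng: np.random.Generator) -> float:
    """Hypothetical sampler: one draw from a normal distribution N(loc, scale**2)."""
    if scale < 0:
        raise ValueError(f"float {scale} is not a valid scale for Gaussian distribution")
    return float(rng.normal(loc, scale))


rng = np.random.default_rng(seed=2)
samples = np.array([play_gaussian(1.5, 0.5, rng) for _ in range(20_000)])
print(samples.mean(), samples.std())  # roughly 1.5 and 0.5

# scale == 0 passes validation and always returns loc.
print(play_gaussian(1.5, 0.0, rng))  # 1.5
```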