
utils

Provides commonly used utility functions.

random_argmax(values, rng)

Computes random argmax of an array.

If the maximum value occurs more than once, one of its indices is chosen uniformly at random.
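The tie-breaking behaviour can be illustrated with the minimal sketch below. It assumes only NumPy and restates the same logic locally, so the snippet runs without mabby installed. `np.argmax` always reports the first maximal index. A random argmax spreads its picks over every tied index.

```python
import numpy as np
from numpy.random import Generator
from numpy.typing import ArrayLike


def random_argmax(values: ArrayLike, rng: Generator) -> int:
    # Local restatement: collect every index that attains the maximum,
    # then pick one of them uniformly at random.
    arr = np.asarray(values)
    candidates = np.flatnonzero(arr == arr.max())
    return int(rng.choice(candidates))


values = [1.0, 3.0, 0.5, 3.0]
rng = np.random.default_rng(0)

# np.argmax always reports the first maximum.
print(np.argmax(values))  # 1

# random_argmax splits its picks between the tied indices 1 and 3.
counts = np.bincount(
    [random_argmax(values, rng) for _ in range(10_000)], minlength=len(values)
)
print(counts)
```

Always taking the first maximum would systematically favour lower indices. Random tie-breaking avoids that bias.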

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| values | ArrayLike | An input array. | required |
| rng | Generator | A NumPy random number generator, used to break ties. | required |

Returns:

| Type | Description |
|------|-------------|
| int | The index of a maximal element of the input array, chosen uniformly at random among ties. |

Source code in mabby/utils.py
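```python
import numpy as np
from numpy.random import Generator
from numpy.typing import ArrayLike


def random_argmax(values: ArrayLike, rng: Generator) -> int:
    """Computes random argmax of an array.

    If there are multiple maximums, the index of one is chosen at random.

    Args:
        values: An input array.
        rng: A random number generator.

    Returns:
        The random argmax of the input array.
    """
    # Indices of every element equal to the maximum (one or more).
    candidates = np.where(values == np.max(values))[0]
    # Pick one tied index uniformly at random and return it as a Python int.
    return int(rng.choice(candidates))
```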
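Here is a short usage sketch, assuming only NumPy. The function is restated so the snippet runs on its own. A unique maximum is returned regardless of the generator state. Ties depend on the generator, so seeding `np.random.default_rng` makes the tie-breaking reproducible.

```python
import numpy as np
from numpy.random import Generator
from numpy.typing import ArrayLike


def random_argmax(values: ArrayLike, rng: Generator) -> int:
    # Same logic as the listing above: all maximal indices, then a random pick.
    candidates = np.where(np.asarray(values) == np.max(values))[0]
    return int(rng.choice(candidates))


# A unique maximum is always returned, whatever the generator state.
unique = random_argmax([0.2, 0.9, 0.4], np.random.default_rng())
print(unique)  # 1

# Two generators with the same seed break ties identically.
rng_a = np.random.default_rng(42)
rng_b = np.random.default_rng(42)
seq_a = [random_argmax([5, 5, 5], rng_a) for _ in range(5)]
seq_b = [random_argmax([5, 5, 5], rng_b) for _ in range(5)]
print(seq_a == seq_b)  # True
```

Passing the generator in explicitly, rather than using a global random state, lets callers control reproducibility.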