src.utils.preds.argmax_pred_select_strategy

Classes

ArgMaxPredSelectStrategy([disabled_classes])

class src.utils.preds.argmax_pred_select_strategy.ArgMaxPredSelectStrategy(disabled_classes=None, **kwargs)
Author:

Alberto M. Esmoris Pena

Select the index of the max prediction from reduced predictions.

The selected prediction for the \(i\)-th point, assuming \(K\) predicted values (e.g., likelihoods for classification), will be:

\[y_{i} = \operatorname*{argmax}_{0 \leq k < K} \quad z_{ik}\]
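The argmax rule above can be sketched with NumPy, assuming the reduced predictions are an \(N \times K\) array with one row per point and one column per predicted value. The function name argmax_select is hypothetical and is not part of this module; it only illustrates the formula.

```python
import numpy as np

def argmax_select(Z):
    """Return, for each point i (row), the index k maximizing Z[i, k]."""
    Z = np.asarray(Z)
    # Take the argmax over the K predicted values (columns) of each point (row)
    return np.argmax(Z, axis=1)

Z = np.array([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])
print(argmax_select(Z).tolist())  # [1, 0, 2]
```

Note that np.argmax resolves ties by returning the first maximal index, so a row such as [0.5, 0.5] selects 0.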

Note that when a single value is given, the selection rounds it to the closest integer, such that:

\[y_{i} = \lfloor{z_i}\rceil\]
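A minimal sketch of the single-value case, assuming one scalar \(z_i\) per point stored in a 1D array. The tie-breaking rule for exact halves is not stated here; this sketch assumes NumPy's np.rint, which rounds halves to the nearest even integer (np.floor(z + 0.5) would round halves upward instead).

```python
import numpy as np

def round_to_closest_int(z):
    """Round each single-value prediction z_i to its closest integer."""
    z = np.asarray(z, dtype=float)
    # np.rint rounds to the nearest integer; exact halves go to the even neighbor
    return np.rint(z).astype(int)

z = np.array([0.2, 1.7, 2.5, 3.5, -0.6])
print(round_to_closest_int(z).tolist())  # [0, 2, 2, 4, -1]
```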

The optional disabled_classes kwarg lets the caller forbid one or more class indices from ever being selected: their columns are masked to -inf before the argmax. This offers a way to “delete” a class from the model’s effective output space without retraining (e.g., the BN-padding-induced “unclassified” sink that biases the SpConv stack toward predicting class 0 for low-signal cells).
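The masking step can be sketched as follows, assuming an \(N \times K\) array of reduced predictions and an iterable of column indices to disable. The helper masked_argmax is hypothetical; it shows the idea of setting disabled columns to -inf so they can never win the argmax.

```python
import numpy as np

def masked_argmax(Z, disabled_classes=None):
    # Copy as float: the caller's array stays untouched, and -inf needs a float dtype
    Z = np.array(Z, dtype=float)
    idx = [] if disabled_classes is None else [int(k) for k in disabled_classes]
    if idx:
        # Disabled columns become -inf, so no point can select them
        Z[:, idx] = -np.inf
    return np.argmax(Z, axis=1)

Z = [[0.5, 0.3, 0.2],
     [0.1, 0.8, 0.1],
     [0.4, 0.35, 0.25]]
print(masked_argmax(Z).tolist())                       # [0, 1, 0]
print(masked_argmax(Z, disabled_classes=[0]).tolist())  # [1, 1, 1]
```

If every class were disabled, each row would be all -inf and np.argmax would return 0, so at least one class should remain enabled.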

See PredSelectStrategy.

__init__(disabled_classes=None, **kwargs)

Initialize/instantiate an argmax prediction selection strategy.

Parameters:
  • disabled_classes – Optional iterable of class indices to mask out before the argmax. Defaults to no masking.

  • kwargs – The attributes for the ArgMaxPredSelectStrategy.

select(reducer, Z)

See PredSelectStrategy and PredSelectStrategy.select().
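How select might combine the two documented cases can be sketched with a hypothetical stand-in class, ArgMaxSelectSketch, which is not the real ArgMaxPredSelectStrategy. It assumes that single-value inputs have shape (N,) or (N, 1), that masking only applies in the multi-value case, and that the reducer argument is not needed for the selection itself (it is accepted but ignored here).

```python
import numpy as np

class ArgMaxSelectSketch:
    """Hypothetical stand-in mirroring the documented interface, not the real class."""

    def __init__(self, disabled_classes=None, **kwargs):
        # Store disabled class indices as a plain list (empty means no masking);
        # extra kwargs are ignored in this sketch
        self.disabled_classes = (
            [] if disabled_classes is None else [int(k) for k in disabled_classes]
        )

    def select(self, reducer, Z):
        # reducer is accepted for interface compatibility but unused here
        Z = np.array(Z, dtype=float)
        if Z.ndim == 1 or Z.shape[1] == 1:
            # Single value per point: round to the closest integer
            return np.rint(Z).astype(int).reshape(-1)
        if self.disabled_classes:
            # Forbid disabled classes before taking the argmax
            Z[:, self.disabled_classes] = -np.inf
        return np.argmax(Z, axis=1)

strategy = ArgMaxSelectSketch(disabled_classes=[0])
print(strategy.select(None, [[0.5, 0.3, 0.2], [0.1, 0.8, 0.1]]).tolist())  # [1, 1]
print(strategy.select(None, [0.4, 1.6, 2.2]).tolist())                     # [0, 2, 2]
```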

static is_single_value(Z)

Check whether the reduced predictions consist of a single scalar per point (True) or not (False).

Parameters:

Z – The reduced predictions to be checked.

Returns:

True if the reduced predictions consist of a single scalar, False otherwise.

Return type:

bool
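A possible shape check for is_single_value, assuming NumPy arrays where a 1D array of shape (N,) or a column of shape (N, 1) counts as one scalar per point, while (N, K) with K > 1 does not. Treating (N, 1) as single-valued is an assumption of this sketch.

```python
import numpy as np

def is_single_value(Z):
    Z = np.asarray(Z)
    # (N,) or (N, 1): one scalar per point; (N, K) with K > 1: K values per point
    return bool(Z.ndim == 1 or (Z.ndim == 2 and Z.shape[1] == 1))

print(is_single_value(np.zeros(5)))       # True
print(is_single_value(np.zeros((5, 1))))  # True
print(is_single_value(np.zeros((5, 3))))  # False
```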

static round_to_closest_int(Z)

Round each reduced prediction to its closest integer.

Parameters:

Z – The reduced predictions to be rounded.

Returns:

Each prediction rounded to its closest integer.