src.model.deeplearn.loss.class_weighted_focal_categorical_crossentropy

Functions

vl3d_class_weighted_focal_categorical_crossentropy(...)

Function to compute a class-weighted focal categorical cross-entropy loss.

src.model.deeplearn.loss.class_weighted_focal_categorical_crossentropy.vl3d_class_weighted_focal_categorical_crossentropy(class_weight, gamma)

Function to compute a class-weighted focal categorical cross-entropy loss.

The focal categorical cross-entropy for a batch of \(B \in \mathbb{Z}_{>0}\) elements is defined as:

\[\mathcal{L} = - B^{-1} \sum_{b=1}^{B}{(1-z^*_b)^{\gamma} \ln(z^*_b)}\]

Where \(z^*_b \in (0, 1) \subset \mathbb{R}\) is the softmax value of the neural network corresponding to the reference class for the \(b\)-th element of the batch (i.e., the predicted probability for the reference class).

The focusing parameter \(\gamma \in \mathbb{R}_{>1}\) governs how much clearly wrong predictions contribute to the gradient relative to clearly correct ones. Because the modulating factor \((1-z^*_b)^{\gamma}\) tends to zero as \(z^*_b \to 1\), greater values of \(\gamma\) increasingly suppress the contribution of confident correct predictions, so the gradient descent-based parameter update is driven mainly by the wrong predictions.
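The effect of the focusing parameter can be checked numerically. The following is a minimal sketch in plain Python, not the library's implementation: the helper names `focal_cce` and `hard_share` are hypothetical, and the inputs are assumed to be the already-extracted reference-class probabilities \(z^*_b\).

```python
import math


def focal_cce(z_ref, gamma):
    """Focal categorical cross-entropy over a batch (hypothetical helper).

    z_ref: reference-class probabilities z*_b, each in (0, 1).
    gamma: focusing parameter.
    """
    return -sum((1.0 - z) ** gamma * math.log(z) for z in z_ref) / len(z_ref)


def hard_share(z_easy, z_hard, gamma):
    """Fraction of the summed loss contributed by the poorly predicted element."""
    term_easy = -((1.0 - z_easy) ** gamma) * math.log(z_easy)
    term_hard = -((1.0 - z_hard) ** gamma) * math.log(z_hard)
    return term_hard / (term_easy + term_hard)


# One confident correct prediction (0.9) and one clearly wrong one (0.1).
for gamma in (1.5, 2.0, 5.0):
    share = hard_share(0.9, 0.1, gamma)
    loss = focal_cce([0.9, 0.1], gamma)
    print(f"gamma={gamma}: loss={loss:.6f}, share of wrong prediction={share:.8f}")
```

As \(\gamma\) grows, the share of the loss attributable to the wrong prediction approaches one, which is the behavior described above.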

The class-weighted focal categorical cross-entropy loss can now be defined as:

\[\widetilde{\mathcal{L}} = - \left( \sum_{b=1}^{B}{\alpha^*_b} \right)^{-1}\sum_{b=1}^{B}{ \alpha^*_b (1-z^*_b)^{\gamma} \ln(z^*_b) }\]

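Where \(\alpha^*_b \in \mathbb{R}\) is the class weight corresponding to the reference class for the \(b\)-th element in the batch. Since the loss is normalized by the sum of the weights, it reduces to \(\mathcal{L}\) when all class weights are equal.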
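The class-weighted formula can be sketched as follows. This is an illustrative plain-Python version under stated assumptions, not the library's code: the function name `class_weighted_focal_cce` is hypothetical, and reference classes are assumed to be given as integer indices into the class weight vector.

```python
import math


def class_weighted_focal_cce(z_ref, ref_class, class_weight, gamma):
    """Class-weighted focal categorical cross-entropy (hypothetical helper).

    z_ref: reference-class probabilities z*_b, each in (0, 1).
    ref_class: integer reference class of each batch element (assumed encoding).
    class_weight: sequence where component c is the weight for class c.
    gamma: focusing parameter.
    """
    # alpha*_b is the weight of the reference class of the b-th element.
    alpha = [class_weight[c] for c in ref_class]
    weighted_sum = sum(
        a * (1.0 - z) ** gamma * math.log(z) for a, z in zip(alpha, z_ref)
    )
    # Normalize by the sum of the weights instead of the batch size.
    return -weighted_sum / sum(alpha)


z_ref = [0.9, 0.1]   # class 0 is predicted well, class 1 poorly
ref_class = [0, 1]
balanced = class_weighted_focal_cce(z_ref, ref_class, [1.0, 1.0], gamma=2.0)
upweighted = class_weighted_focal_cce(z_ref, ref_class, [1.0, 4.0], gamma=2.0)
print(balanced, upweighted)
```

Because of the normalization, multiplying every class weight by the same positive constant leaves the loss unchanged; only the relative weights matter.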

Finally, to obtain a point-wise class-weighted focal categorical cross-entropy, each of the \(B\) batch elements is considered to contain \(m\) points, indexed by \(i\), such that:

\[\widehat{\mathcal{L}} = -\left( \sum_{b=1}^{B}\sum_{i=1}^{m}{\alpha_{bi}^*} \right)^{-1} \sum_{b=1}^{B}{\sum_{i=1}^{m}{ \alpha^*_{bi} (1-z^*_{bi})^{\gamma} \ln(z^*_{bi}) }}\]
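
The double sum of the point-wise loss can be sketched directly with nested loops. This is a minimal plain-Python illustration under stated assumptions, not the library's implementation: the name `pointwise_class_weighted_focal_cce` is hypothetical, the labels are assumed to be integer class indices of shape (B, m), and the predictions are assumed to be softmax outputs of shape (B, m, number of classes).

```python
import math


def pointwise_class_weighted_focal_cce(y_true, y_pred, class_weight, gamma):
    """Point-wise class-weighted focal categorical cross-entropy (sketch).

    y_true: nested list (B x m) of integer reference classes (assumed encoding).
    y_pred: nested list (B x m x C) of softmax probabilities.
    class_weight: sequence where component c is the weight for class c.
    gamma: focusing parameter.
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for labels_b, probs_b in zip(y_true, y_pred):    # batch elements b
        for c_ref, probs in zip(labels_b, probs_b):  # points i
            alpha = class_weight[c_ref]  # alpha*_bi
            z = probs[c_ref]             # z*_bi
            weighted_sum += alpha * (1.0 - z) ** gamma * math.log(z)
            weight_total += alpha
    return -weighted_sum / weight_total


# B = 2 batch elements, m = 2 points each, 2 classes.
y_true = [[0, 1], [1, 1]]
y_pred = [
    [[0.8, 0.2], [0.3, 0.7]],
    [[0.6, 0.4], [0.1, 0.9]],
]
loss = pointwise_class_weighted_focal_cce(y_true, y_pred, [1.0, 3.0], gamma=2.0)
print(loss)
```

A vectorized implementation in a deep learning framework would compute the same quantity with tensor operations; the loops here are kept only to mirror the indices \(b\) and \(i\) of the formula.
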
Parameters:
  • class_weight – The vector of class weights \(\pmb{\alpha}\). Its component \(c\) is the weight for class \(c\).

  • gamma (float) – The focusing parameter \(\gamma \in \mathbb{R}_{>1}\).

Returns:

The class-weighted focal categorical cross-entropy loss.