src.model.deeplearn.loss.class_weighted_focal_binary_crossentropy
Functions
Function to compute a class-weighted focal binary cross-entropy loss.
- src.model.deeplearn.loss.class_weighted_focal_binary_crossentropy.vl3d_class_weighted_focal_binary_crossentropy(class_weight, gamma)
Function to compute a class-weighted focal binary cross-entropy loss.
The focal binary cross-entropy for a batch of \(B \in \mathbb{Z}_{>0}\) elements is defined as:
\[\mathcal{L} = - B^{-1} \sum_{b=1}^{B}{(1-z^*_b)^{\gamma} \ln(z^*_b)}\]
Where \(z^*_b \in (0, 1) \subset \mathbb{R}\) is the sigmoid value of the neural network corresponding to the reference class for the \(b\)-th element of the batch. That is, \(z^*_b = y_b \hat{y}_b + (1 - y_b)(1 - \hat{y}_b)\), where \(y_b \in \{0, 1\}\) is the true label and \(\hat{y}_b \in (0, 1)\) is the predicted probability.
The focusing parameter \(\gamma \in \mathbb{R}_{>1}\) governs how much clearly wrong predictions contribute to the gradient compared to clearly correct predictions. Greater values of \(\gamma\) increasingly down-weight correct predictions, so the gradient descent-based parameter update is dominated by wrong predictions.
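The effect of \(\gamma\) can be illustrated with a minimal pure-Python sketch of the formula above. The helper names `focal_term` and `focal_bce` are hypothetical and are not part of the documented module; predictions are assumed to lie strictly inside \((0, 1)\), so no clipping is applied.

```python
import math


def focal_term(z, gamma):
    """Per-element focal term -(1 - z)^gamma * ln(z) for z in (0, 1)."""
    return -((1.0 - z) ** gamma) * math.log(z)


def focal_bce(y_true, y_pred, gamma):
    """Batch mean of the focal term (hypothetical helper, not the library code)."""
    terms = []
    for y, p in zip(y_true, y_pred):
        # z* is the probability assigned to the reference (true) class
        z = y * p + (1 - y) * (1 - p)
        terms.append(focal_term(z, gamma))
    return sum(terms) / len(terms)


# Compare a clearly correct prediction (z* = 0.9) with a clearly wrong one (z* = 0.1)
for gamma in (2.0, 5.0):
    right = focal_term(0.9, gamma)
    wrong = focal_term(0.1, gamma)
    print(f"gamma={gamma}: correct={right:.3e}, wrong={wrong:.3e}, ratio={wrong / right:.1f}")
```

The ratio between the wrong and the correct term grows with \(\gamma\), which is the down-weighting behavior described above.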
The class-weighted focal binary cross-entropy loss can now be defined as:
\[\widetilde{\mathcal{L}} = - \left( \sum_{b=1}^{B}{\alpha^*_b} \right)^{-1}\sum_{b=1}^{B}{ \alpha^*_b (1-z^*_b)^{\gamma} \ln(z^*_b) }\]
Where \(\alpha^*_b \in \mathbb{R}\) is the class-weight corresponding to the reference class for the \(b\)-th element in the batch. That is, \(\alpha^*_b = y_b w_1 + (1 - y_b) w_0\), where \(\pmb{w} \in \mathbb{R}^2\) is the vector of class weights.
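As a hedged illustration of the weighted formula, the following pure-Python sketch computes \(\widetilde{\mathcal{L}}\) for a small batch. The function name `class_weighted_focal_bce` is hypothetical, `class_weight` is assumed to be indexable as `[w_0, w_1]`, and predictions are assumed to lie strictly inside \((0, 1)\).

```python
import math


def class_weighted_focal_bce(y_true, y_pred, class_weight, gamma):
    """Weighted focal BCE normalized by the sum of per-element weights.

    Hypothetical helper for illustration; class_weight = [w_0, w_1].
    """
    numerator = 0.0
    denominator = 0.0
    for y, p in zip(y_true, y_pred):
        # Probability assigned to the reference class
        z = y * p + (1 - y) * (1 - p)
        # Weight of the reference class
        alpha = y * class_weight[1] + (1 - y) * class_weight[0]
        numerator += alpha * (1.0 - z) ** gamma * math.log(z)
        denominator += alpha
    return -numerator / denominator


y_true = [1, 0, 1]
y_pred = [0.8, 0.3, 0.4]
loss = class_weighted_focal_bce(y_true, y_pred, class_weight=[1.0, 3.0], gamma=2.0)
print(loss)
```

Because the sum is divided by \(\sum_b \alpha^*_b\), only the ratio between the class weights matters: multiplying both weights by the same positive constant leaves the loss unchanged.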
Finally, the point-wise class-weighted focal binary cross-entropy considers \(m\) points for each element of the batch, such that:
\[\widehat{\mathcal{L}} = -\left( \sum_{b=1}^{B}\sum_{i=1}^{m}{\alpha_{bi}^*} \right)^{-1} \sum_{b=1}^{B}{\sum_{i=1}^{m}{ \alpha^*_{bi} (1-z^*_{bi})^{\gamma} \ln(z^*_{bi}) }}\]
- Parameters:
class_weight – The vector of class weights. The component \(c\) of this vector (\(\pmb{w}\)) is the weight \(w_c\) for class \(c\).
gamma (float) – The focusing parameter \(\gamma \in \mathbb{R}_{>1}\).
- Returns:
The class-weighted focal binary cross-entropy loss.
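The point-wise definition can be sketched in pure Python as follows. This is an illustration under stated assumptions and not the library implementation: the factory `make_pointwise_loss` is hypothetical and only mirrors the `(class_weight, gamma)` signature; whether the documented function returns a callable is not stated on this page. Inputs are assumed to be nested lists of shape \(B \times m\) with predictions strictly inside \((0, 1)\).

```python
import math


def make_pointwise_loss(class_weight, gamma):
    """Hypothetical factory mirroring the (class_weight, gamma) signature."""
    w0, w1 = class_weight

    def loss(y_true, y_pred):
        # y_true and y_pred hold B rows with m points each
        numerator = 0.0
        denominator = 0.0
        for labels, probs in zip(y_true, y_pred):
            for y, p in zip(labels, probs):
                z = y * p + (1 - y) * (1 - p)   # reference-class probability
                alpha = y * w1 + (1 - y) * w0    # reference-class weight
                numerator += alpha * (1.0 - z) ** gamma * math.log(z)
                denominator += alpha
        return -numerator / denominator

    return loss


loss_fn = make_pointwise_loss(class_weight=[0.25, 0.75], gamma=2.0)
y_true = [[1, 0, 1], [0, 0, 1]]              # B = 2 elements, m = 3 points each
y_pred = [[0.9, 0.2, 0.4], [0.1, 0.7, 0.8]]
value = loss_fn(y_true, y_pred)
print(value)
```

Since both sums run over every point of every batch element, flattening the batch into a single row of \(B \cdot m\) points gives the same value.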