src.model.deeplearn.optimizer.torf_centralized_adam

Classes

TorfCentralizedAdam([gc])

class src.model.deeplearn.optimizer.torf_centralized_adam.TorfCentralizedAdam(gc=True, **kwargs)
Author:

Alberto M. Esmoris Pena

Adam optimizer with post-update gradient centralization for the TransfOctoRF pipeline.

Motivation. Standard gradient centralization (Yong et al., 2020) subtracts the mean from each gradient tensor before the optimizer processes it. When combined with Adam, this causes the running first moment \(m_t\) and second moment \(v_t\) to track the statistics of the centralized gradient rather than those of the original one. Because centralization removes the mean component of the gradient, it reduces the gradient's magnitude. As a result, \(v_t\) is smaller than it would be for the raw gradient, and Adam’s per-parameter adaptive learning rate \(\alpha / (\sqrt{\hat{v}_t} + \epsilon)\) overestimates the step size, destabilizing training.

This implementation applies centralization after Adam computes the full adaptive update, so the moments track the original gradient (correct adaptive rates) while the centralization constrains the update direction only.
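The NumPy sketch below illustrates the motivation with synthetic gradients. The helper names are hypothetical and this is not the TransfOctoRF code. When the second moment is summed over the centralized (input) axis, the value accumulated from centralized gradients can never exceed the value accumulated from raw gradients, because \(\sum_i (g_i - \bar{g})^2 = \sum_i g_i^2 - k\,\bar{g}^2\) for the \(k\) entries being averaged. In aggregate, a smaller \(v_t\) enlarges the adaptive step \(\alpha / (\sqrt{\hat{v}_t} + \epsilon)\).

```python
import numpy as np

def centralize(x):
    """Subtract the mean over all axes except the last (output) axis."""
    return x - x.mean(axis=tuple(range(x.ndim - 1)), keepdims=True)

def second_moment(grads, beta_2=0.999, pre_centralize=False):
    """Accumulate v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t**2 over a gradient sequence."""
    v = np.zeros_like(grads[0])
    for g in grads:
        if pre_centralize:
            g = centralize(g)  # standard GC: Adam only ever sees the centralized gradient
        v = beta_2 * v + (1.0 - beta_2) * g**2
    return v

rng = np.random.default_rng(0)
# Hypothetical gradients for a (4, 3) weight with a nonzero mean component.
grads = [rng.normal(loc=0.5, scale=1.0, size=(4, 3)) for _ in range(50)]

v_raw = second_moment(grads)                       # post-update GC: moments see raw g_t
v_pre = second_moment(grads, pre_centralize=True)  # pre-update GC: moments see centralized g_t
print(v_pre.sum(axis=0), v_raw.sum(axis=0))        # per-column totals; pre <= raw
```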

Algorithm. Given the gradient \(g_t\) of a weight tensor \(\mathbf{W}\) with shape \((d_1, d_2, \ldots, d_n)\), where \(n \geq 2\):

Step 1 — First moment (exponential moving average of gradients):

\[m_t = \beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t\]

Step 2 — Second moment (exponential moving average of squared gradients):

\[v_t = \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2\]
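Steps 1 and 2 are plain element-wise exponential moving averages. The NumPy sketch below shows them in isolation. The function name is hypothetical, and the beta values are the Keras Adam defaults.

```python
import numpy as np

def update_moments(m, v, g, beta_1=0.9, beta_2=0.999):
    """Steps 1-2: exponential moving averages of g_t and g_t**2 (element-wise)."""
    m = beta_1 * m + (1.0 - beta_1) * g
    v = beta_2 * v + (1.0 - beta_2) * np.square(g)
    return m, v

# First iteration from zero-initialized moments.
g = np.array([2.0, -1.0])
m, v = update_moments(np.zeros(2), np.zeros(2), g)
print(m, v)  # m is approximately [0.2, -0.1], v approximately [0.004, 0.001]
```

Because \(m_0 = v_0 = 0\), both moments start biased toward zero. This is the bias that Step 3 compensates for.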

Step 3 — Bias-corrected learning rate:

\[\alpha_t = \eta \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}\]

where \(\eta\) is the base learning rate and \(t\) is the current iteration. This folds the bias correction into the step size rather than computing \(\hat{m}_t = m_t / (1 - \beta_1^t)\) and \(\hat{v}_t = v_t / (1 - \beta_2^t)\) separately.
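The folded form differs from the textbook form only in where \(\epsilon\) enters. With \(\epsilon = 0\), the two are identical. With \(\epsilon > 0\), the folded form equals the textbook form with \(\epsilon\) replaced by \(\epsilon / \sqrt{1 - \beta_2^t}\). The NumPy check below, which uses hypothetical function names, verifies both statements numerically.

```python
import numpy as np

def folded_delta(m, v, lr, t, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """Steps 3-4: bias correction folded into the step size alpha_t."""
    alpha_t = lr * np.sqrt(1.0 - beta_2**t) / (1.0 - beta_1**t)
    return alpha_t * m / (np.sqrt(v) + eps)

def textbook_delta(m, v, lr, t, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """Textbook Adam: bias-correct m_t and v_t separately, then take the step."""
    m_hat = m / (1.0 - beta_1**t)
    v_hat = v / (1.0 - beta_2**t)
    return lr * m_hat / (np.sqrt(v_hat) + eps)

m, v, t, lr = np.array([0.3, -0.2]), np.array([0.05, 0.01]), 10, 1e-3

# With eps = 0 the two forms coincide.
same_no_eps = np.allclose(folded_delta(m, v, lr, t, eps=0.0),
                          textbook_delta(m, v, lr, t, eps=0.0))
# With eps > 0 they coincide once the textbook eps is rescaled by 1 / sqrt(1 - beta_2**t).
eps = 1e-7
same_scaled_eps = np.allclose(folded_delta(m, v, lr, t, eps=eps),
                              textbook_delta(m, v, lr, t, eps=eps / np.sqrt(1.0 - 0.999**t)))
print(same_no_eps, same_scaled_eps)  # → True True
```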

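AMSGrad variant (optional, amsgrad=True): the element-wise running maximum of all past second moments is used in place of \(v_t\):

\[v_t^{\max} = \max(v_{t-1}^{\max}, v_t)\]

In subsequent formulas, \(v_t\) is replaced by \(v_t^{\max}\) when AMSGrad is enabled. The separate symbol avoids a clash with the bias-corrected \(\hat{v}_t\).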
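The AMSGrad running maximum can be illustrated with a three-step NumPy sketch. The \(v_t\) values are made up, and the helper name is hypothetical.

```python
import numpy as np

def amsgrad_max(v_max, v):
    """AMSGrad: element-wise running maximum of the second moment."""
    return np.maximum(v_max, v)

v_max = np.zeros(2)
# Hypothetical second-moment values over three iterations.
for v in (np.array([0.04, 0.01]), np.array([0.02, 0.03]), np.array([0.01, 0.02])):
    v_max = amsgrad_max(v_max, v)
print(v_max)  # → [0.04 0.03]
```

Each element keeps its largest value seen so far. The denominator of Step 4 therefore never shrinks, and the per-element effective step size is non-increasing apart from changes in \(\alpha_t\) and \(m_t\).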

Step 4 — Adaptive update (uncentralized):

\[\Delta \mathbf{W} = \frac{\alpha_t \, m_t}{\sqrt{v_t} + \epsilon}\]

where \(\epsilon\) is a small constant for numerical stability and all operations are element-wise.
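As a worked check of Steps 1 to 4, take the first iteration \(t = 1\) with \(m_0 = v_0 = 0\). This example is an added illustration, not part of the original derivation. Then \(m_1 = (1 - \beta_1) g_1\) and \(v_1 = (1 - \beta_2) g_1^2\), so \(\sqrt{v_1} = \sqrt{1 - \beta_2}\,|g_1|\), and the bias-correction factors cancel:

\[\Delta \mathbf{W} = \eta \, \frac{\sqrt{1 - \beta_2}}{1 - \beta_1} \cdot \frac{(1 - \beta_1) \, g_1}{\sqrt{1 - \beta_2} \, |g_1| + \epsilon} = \frac{\eta \, g_1}{|g_1| + \epsilon / \sqrt{1 - \beta_2}} \approx \eta \, \operatorname{sign}(g_1)\]

The first uncentralized step therefore has magnitude close to \(\eta\) in every element where \(|g_1| \gg \epsilon / \sqrt{1 - \beta_2}\).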

Step 5 — Post-update centralization. The mean of \(\Delta \mathbf{W}\) is subtracted over all axes except the last (output) dimension:

\[\overline{\Delta W}_j = \frac{1}{\prod_{i=1}^{n-1} d_i} \sum_{i_1, \ldots, i_{n-1}} \Delta W_{i_1 \ldots i_{n-1} j}\]
\[\widehat{\Delta W}_{i_1 \ldots i_{n-1} j} = \Delta W_{i_1 \ldots i_{n-1} j} - \overline{\Delta W}_j\]

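After this step, for every output index \(j\) the entries of \(\widehat{\Delta \mathbf{W}}\) sum to zero over the input axes. In other words, each output slice of the update lies in the hyperplane orthogonal to the all-ones vector along the input dimensions. This removes any common shift shared by all weights feeding the same output unit, which prevents correlated weight drifts.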
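Step 5 is sketched below in NumPy. The function name is hypothetical, and the 4-D example shape follows the Keras channels-last convolution kernel layout \((k, k, c_{in}, c_{out})\) as an illustration. The sketch checks that every output slice of the centralized update sums to zero over the input axes.

```python
import numpy as np

def centralize_update(delta):
    """Step 5: subtract the mean over all axes except the last (output) axis.

    Only applied to tensors of rank >= 2; lower-rank tensors are returned unchanged.
    """
    if delta.ndim < 2:
        return delta
    axes = tuple(range(delta.ndim - 1))  # (0, ..., n-2): the input axes
    return delta - delta.mean(axis=axes, keepdims=True)

rng = np.random.default_rng(0)
delta = rng.normal(size=(3, 3, 8, 16))  # e.g. a conv kernel update: (k, k, c_in, c_out)
delta_c = centralize_update(delta)

# For every output index j the centralized slice sums to (numerically) zero,
# i.e. it is orthogonal to the all-ones vector over the input axes.
print(np.abs(delta_c.sum(axis=(0, 1, 2))).max())
```

The `keepdims=True` argument keeps the mean with shape \((1, \ldots, 1, d_n)\), so it broadcasts against \(\Delta \mathbf{W}\) exactly as in the index-wise formula.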

Step 6 — Parameter update:

\[\mathbf{W}_{t+1} = \mathbf{W}_t - \widehat{\Delta \mathbf{W}}\]

For bias vectors and other parameters of rank < 2 (scalars and 1-D tensors), Step 5 is skipped and Step 6 reduces to the standard Adam update \(\mathbf{W}_{t+1} = \mathbf{W}_t - \Delta \mathbf{W}\) (no centralization).
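A compact NumPy reference for one full step (Steps 1 to 6) is sketched below. It is an illustration under stated assumptions, not the class's actual update_step. It uses NumPy instead of Keras ops, takes Keras-style Adam defaults (lr=1e-3, beta_1=0.9, beta_2=0.999, eps=1e-7), and omits AMSGrad. The function name centralized_adam_step is hypothetical. The sketch shows three properties:

- the moments are updated from the raw gradient;
- rank ≥ 2 updates have zero mean over the input axes;
- 1-D parameters fall back to plain Adam, as do all parameters when gc=False.

```python
import numpy as np

def centralized_adam_step(w, g, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999,
                          eps=1e-7, gc=True):
    """One Adam step with post-update centralization (NumPy sketch, no AMSGrad).

    t is the 1-based iteration count. Returns the new (w, m, v).
    """
    m = beta_1 * m + (1.0 - beta_1) * g                          # Step 1 (raw gradient)
    v = beta_2 * v + (1.0 - beta_2) * g**2                       # Step 2 (raw gradient)
    alpha_t = lr * np.sqrt(1.0 - beta_2**t) / (1.0 - beta_1**t)  # Step 3
    delta = alpha_t * m / (np.sqrt(v) + eps)                     # Step 4 (uncentralized)
    if gc and w.ndim >= 2:                                       # Step 5: rank >= 2 only
        delta = delta - delta.mean(axis=tuple(range(w.ndim - 1)), keepdims=True)
    return w - delta, m, v                                       # Step 6

rng = np.random.default_rng(1)
W, b = rng.normal(size=(5, 4)), rng.normal(size=4)    # hypothetical weight and bias
gW, gb = rng.normal(size=(5, 4)), rng.normal(size=4)  # hypothetical gradients
zW, zb = np.zeros_like(W), np.zeros_like(b)           # zero-initialized moments

W1, mW, vW = centralized_adam_step(W, gW, zW, zW, t=1)
b1, _, _ = centralized_adam_step(b, gb, zb, zb, t=1)
b1_plain, _, _ = centralized_adam_step(b, gb, zb, zb, t=1, gc=False)

print(np.abs((W - W1).mean(axis=0)).max())  # column means of the applied update, ~0
print(np.array_equal(b1, b1_plain))          # 1-D bias: identical to plain Adam
```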

When gc=False, the optimizer behaves identically to standard keras.optimizers.Adam.

Variables:

gc (bool) – Enable post-update gradient centralization.

__init__(gc=True, **kwargs)

Initialize TorfCentralizedAdam.

Parameters:

gc (bool) – Enable post-update gradient centralization (default True).

update_step(gradient, variable, learning_rate)

Apply one Adam update to variable, followed by post-update centralization when gc is enabled.

Parameters:
  • gradient – Gradient tensor for this variable.

  • variable – The weight variable to update.

  • learning_rate – Current learning rate.