src.model.deeplearn.optimizer.torf_centralized_sgd

Classes

TorfCentralizedSGD([gc])

class src.model.deeplearn.optimizer.torf_centralized_sgd.TorfCentralizedSGD(gc=True, **kwargs)
Author:

Alberto M. Esmoris Pena

SGD optimizer with post-update gradient centralization for the TransfOctoRF pipeline.

Applies the same post-update centralization strategy as TorfCentralizedAdam: the mean is subtracted from the computed update, not from the raw gradient, so the momentum buffer is never centralized.

Centralization operator. For a weight tensor \(\mathbf{W}\) with shape \((d_1, d_2, \ldots, d_n)\) where \(n \geq 2\), the centralization operator \(\mathcal{C}\) subtracts the mean over all axes except the last:

\[\mathcal{C}(\Delta \mathbf{W})_{i_1 \ldots i_n} = \Delta W_{i_1 \ldots i_n} - \frac{1}{\prod_{k=1}^{n-1} d_k} \sum_{j_1, \ldots, j_{n-1}} \Delta W_{j_1 \ldots j_{n-1} i_n}\]

For tensors of rank below 2 (e.g., biases), \(\mathcal{C}\) is the identity.
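The operator can be sketched in a few lines of NumPy. This is an illustrative stand-in, not the class's own code: the helper name `centralize` mirrors the method documented on this page but omits its `variable` argument, and the sample arrays are made up for the example.

```python
import numpy as np


def centralize(update: np.ndarray) -> np.ndarray:
    """Subtract the mean over all axes except the last (identity for rank < 2)."""
    if update.ndim < 2:
        return update
    # Reduce over every axis but the last one, e.g. (0,) for a dense kernel
    # and (0, 1, 2) for a 4-D convolution kernel.
    axes = tuple(range(update.ndim - 1))
    return update - update.mean(axis=axes, keepdims=True)


# Rank-2 example: the column means are 2.0 and 4.0.
delta = np.array([[1.0, 2.0], [3.0, 6.0]])
centered = centralize(delta)

# Rank-1 example: returned unchanged.
bias = np.array([0.5, -0.5])
bias_out = centralize(bias)

# Rank-4 example: each slice along the last axis ends up with zero mean.
kernel = np.arange(24.0).reshape(2, 2, 2, 3)
kernel_centered = centralize(kernel)
```

After centralization, the mean taken over the reduced axes is zero for every index along the last axis.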

Case 1 — No momentum (\(\mu = 0\)):

\[\Delta \mathbf{W} = \eta \, g_t, \quad \mathbf{W}_{t+1} = \mathbf{W}_t - \mathcal{C}(\Delta \mathbf{W})\]

Case 2 — Classical momentum (\(\mu > 0\), nesterov=False):

The momentum buffer \(m_t\) accumulates the velocity:

\[m_t = \mu \, m_{t-1} - \eta \, g_t\]

The centralized update is applied additively (the sign is already embedded in \(m_t\)):

\[\mathbf{W}_{t+1} = \mathbf{W}_t + \mathcal{C}(m_t)\]
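A minimal NumPy sketch of one classical-momentum step, following the two equations above. The function name `momentum_step`, its default hyperparameters, and the `centralize` helper are hypothetical names chosen for illustration; they are not part of the documented class.

```python
import numpy as np


def centralize(update):
    """Subtract the mean over all axes except the last (identity for rank < 2)."""
    if update.ndim < 2:
        return update
    return update - update.mean(axis=tuple(range(update.ndim - 1)), keepdims=True)


def momentum_step(w, m, g, lr=0.1, mu=0.9, gc=True):
    """One Case 2 step: returns the new weights and the new momentum buffer."""
    m_new = mu * m - lr * g                      # buffer built from the raw gradient
    step = centralize(m_new) if gc else m_new    # centralize only the applied step
    return w + step, m_new                       # additive: sign lives in m_new


w0 = np.zeros((2, 2))
m0 = np.zeros((2, 2))
g = np.array([[1.0, 2.0], [3.0, 6.0]])

w1, m1 = momentum_step(w0, m0, g)
w1_plain, _ = momentum_step(w0, m0, g, gc=False)
```

With a zero initial buffer, `m1` equals `-lr * g`, while `w1` receives that velocity with its per-column mean removed. With `gc=False` the step reduces to plain momentum SGD.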

Case 3 — Nesterov momentum (\(\mu > 0\), nesterov=True):

The momentum buffer is updated identically to Case 2. The Nesterov look-ahead computes the update as:

\[\Delta \mathbf{W} = \mu \, m_t - \eta \, g_t\]
\[\mathbf{W}_{t+1} = \mathbf{W}_t + \mathcal{C}(\Delta \mathbf{W})\]
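The Nesterov case differs from Case 2 only in which tensor is centralized and applied. The sketch below assumes the same hypothetical `centralize` helper as a NumPy stand-in; `nesterov_step` is likewise an illustrative name.

```python
import numpy as np


def centralize(update):
    """Subtract the mean over all axes except the last (identity for rank < 2)."""
    if update.ndim < 2:
        return update
    return update - update.mean(axis=tuple(range(update.ndim - 1)), keepdims=True)


def nesterov_step(w, m, g, lr=0.1, mu=0.9):
    """One Case 3 step: returns the new weights and the new momentum buffer."""
    m_new = mu * m - lr * g          # same buffer update as Case 2
    delta = mu * m_new - lr * g      # look-ahead update
    return w + centralize(delta), m_new


w0 = np.zeros((2, 2))
m0 = np.zeros((2, 2))
g = np.array([[1.0, 2.0], [3.0, 6.0]])

w1, m1 = nesterov_step(w0, m0, g)
```

Starting from a zero buffer with `lr=0.1` and `mu=0.9`, the look-ahead update is `-0.19 * g` before centralization.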

In all cases, centralization is applied to the final update rather than to the raw gradient, so the momentum buffer \(m_t\) tracks the original (uncentralized) gradient-derived velocity.
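One consequence can be checked numerically: because every applied step has zero mean over the non-last axes, the weight means over those axes stay fixed, while the buffer keeps whatever mean the gradients carry. The sketch below uses synthetic gradients and a hypothetical NumPy `centralize` helper, so it illustrates the equations rather than the class itself.

```python
import numpy as np


def centralize(update):
    """Subtract the mean over all axes except the last (identity for rank < 2)."""
    if update.ndim < 2:
        return update
    return update - update.mean(axis=tuple(range(update.ndim - 1)), keepdims=True)


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3))
m = np.zeros_like(w)
col_mean_before = w.mean(axis=0)

lr, mu = 0.1, 0.9
for _ in range(5):
    g = rng.normal(size=w.shape) + 2.0   # synthetic gradient with a nonzero mean
    m = mu * m - lr * g                  # buffer accumulates the raw velocity
    w = w + centralize(m)                # only the applied step is centralized

col_mean_after = w.mean(axis=0)
buffer_col_mean = m.mean(axis=0)
```

Here `col_mean_after` matches `col_mean_before` up to floating-point error, whereas `buffer_col_mean` is clearly nonzero.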

When gc=False, the optimizer behaves identically to standard keras.optimizers.SGD.

Variables:

gc (bool) – Enable post-update gradient centralization.

__init__(gc=True, **kwargs)

Initialize TorfCentralizedSGD.

Parameters:

gc (bool) – Enable post-update gradient centralization. Defaults to True.

centralize(update, variable)

Apply the centralization operator \(\mathcal{C}\) to an update tensor.

For rank \(\geq 2\), subtracts the mean over all axes except the last. For rank \(< 2\), returns the update unchanged.

Parameters:
  • update – The weight update tensor.

  • variable – The associated weight variable.

Returns:

Centralized update.

update_step(gradient, variable, learning_rate)

SGD update with post-step centralization.

Parameters:
  • gradient – Gradient tensor for this variable.

  • variable – The weight variable to update.

  • learning_rate – Current learning rate.