src.model.deeplearn.optimizer.torf_centralized_sgd
Classes
- class src.model.deeplearn.optimizer.torf_centralized_sgd.TorfCentralizedSGD(gc=True, **kwargs)
- Author:
Alberto M. Esmoris Pena
SGD optimizer with post-update gradient centralization for the TransfOctoRF pipeline.
Applies the same post-update centralization strategy as TorfCentralizedAdam: the mean is subtracted from the computed update (not the raw gradient), preserving the momentum buffer statistics.
Centralization operator. For a weight tensor \(\mathbf{W}\) with shape \((d_1, d_2, \ldots, d_n)\) where \(n \geq 2\), the centralization operator \(\mathcal{C}\) subtracts the mean over all axes except the last:
\[\mathcal{C}(\Delta \mathbf{W})_{i_1 \ldots i_n} = \Delta W_{i_1 \ldots i_n} - \frac{1}{\prod_{k=1}^{n-1} d_k} \sum_{j_1, \ldots, j_{n-1}} \Delta W_{j_1 \ldots j_{n-1} i_n}\]
For tensors of rank \(n < 2\) (e.g., rank-1 biases), \(\mathcal{C}\) is the identity.
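As a small worked example (the numbers are illustrative, not taken from the pipeline), consider a rank-2 update. The mean runs over the first axis only, so each column is shifted to zero mean:
\[\Delta \mathbf{W} = \begin{pmatrix} 1 & 2 \\ 3 & 6 \end{pmatrix}, \quad \frac{1}{2} \sum_{j_1} \Delta W_{j_1 i_2} = \begin{pmatrix} 2 & 4 \end{pmatrix}, \quad \mathcal{C}(\Delta \mathbf{W}) = \begin{pmatrix} -1 & -2 \\ 1 & 2 \end{pmatrix}\]
Each column of \(\mathcal{C}(\Delta \mathbf{W})\) sums to zero, while the last-axis index \(i_2\) is never averaged over.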
In the update rules below, \(\eta\) is the learning rate, \(\mu\) the momentum coefficient, and \(g_t\) the gradient at step \(t\).
Case 1 (no momentum, \(\mu = 0\)):
\[\Delta \mathbf{W} = \eta \, g_t, \quad \mathbf{W}_{t+1} = \mathbf{W}_t - \mathcal{C}(\Delta \mathbf{W})\]
Case 2 (classical momentum, \(\mu > 0\), nesterov=False): the momentum buffer \(m_t\) accumulates the velocity:
\[m_t = \mu \, m_{t-1} - \eta \, g_t\]
The centralized update is applied additively (the sign is already embedded in \(m_t\)):
\[\mathbf{W}_{t+1} = \mathbf{W}_t + \mathcal{C}(m_t)\]
Case 3 (Nesterov momentum, \(\mu > 0\), nesterov=True): the momentum buffer is updated exactly as in Case 2, and the Nesterov look-ahead computes the update as:
\[\Delta \mathbf{W} = \mu \, m_t - \eta \, g_t\]
\[\mathbf{W}_{t+1} = \mathbf{W}_t + \mathcal{C}(\Delta \mathbf{W})\]
In all cases, centralization is applied to the final update rather than to the raw gradient, so the momentum buffer \(m_t\) tracks the original (uncentralized) gradient-derived velocity.
When gc=False, the optimizer behaves identically to the standard keras.optimizers.SGD.
- Variables:
gc (bool) – Enable post-update gradient centralization.
- __init__(gc=True, **kwargs)
Initialize TorfCentralizedSGD.
- Parameters:
gc (bool) – Enable post-update gradient centralization.
- centralize(update, variable)
Apply the centralization operator \(\mathcal{C}\) to an update tensor.
For rank \(\geq 2\), subtracts the mean over all axes except the last. For rank \(< 2\), returns the update unchanged.
- Parameters:
update – The weight update tensor.
variable – The associated weight variable.
- Returns:
Centralized update.
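The rank test and axis selection of centralize can be sketched in NumPy, assuming plain arrays in place of Keras backend tensors. The `variable` argument is kept only to mirror the documented signature; how the actual method uses it is not described here, so this sketch ignores it.

```python
import numpy as np


def centralize(update, variable=None):
    """NumPy sketch of the centralization operator C.

    `variable` mirrors the documented signature but is unused here
    (assumption: update and variable share the same shape).
    """
    update = np.asarray(update)
    if update.ndim < 2:
        # Scalars and rank-1 tensors (biases): C is the identity.
        return update
    # Mean over every axis except the last one (the output axis of Keras kernels).
    axes = tuple(range(update.ndim - 1))
    return update - update.mean(axis=axes, keepdims=True)


# Example: a conv-like kernel update of shape (kernel_size, in_channels, out_channels).
kernel_update = np.arange(24, dtype=float).reshape(2, 3, 4)
centered = centralize(kernel_update)
print(centered.mean(axis=(0, 1)))  # each output channel's mean is (numerically) zero
```

Taking the mean with `keepdims=True` keeps the reduced axes as size-1 dimensions, so the subtraction broadcasts back to the original shape.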
- update_step(gradient, variable, learning_rate)
SGD update with post-step centralization.
- Parameters:
gradient – Gradient tensor for this variable.
variable – The weight variable to update.
learning_rate – Current learning rate.