src.model.deeplearn.layer.masked_batch_normalization
Classes
- class src.model.deeplearn.layer.masked_batch_normalization.MaskedBatchNormalization(*args, **kwargs)
- Author:
Alberto M. Esmoris Pena
Batch normalization that ignores padded rows when computing the batch statistics. Designed for use with DLSparseConcatSequencer and the SpConv layer stack: every batch is statically padded along its row axis so that tf.function traces only once. The padded rows are zeros everywhere, which biases the statistics of a vanilla keras.layers.BatchNormalization. Its batch mean, for instance, is the real-row mean scaled by the fraction of real rows, so it is pulled toward zero in proportion to the padding ratio, and the running mean/variance inherit this bias.
Forward contract:
Inputs are a 2-tuple (x, mask), where x has shape (N, C) and mask is a 1-D boolean tensor of shape (N,). mask[i] == True marks the i-th row of x as a real cell that should contribute to the batch statistics; mask[i] == False marks the i-th row as padding.
During training, the mean and variance are computed only over the rows where mask is True. The running mean and running variance are updated with these masked statistics following the standard exponential-moving-average rule with momentum self.momentum.
During inference, the running statistics are used (vanilla BN semantics).
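A minimal NumPy sketch of the training-time statistics described above. It assumes the biased (divide-by-count) variance and the Keras-style moving-average rule moving = moving * momentum + batch * (1 - momentum). The helper names masked_moments and ema_update are hypothetical and do not correspond to the layer's internals. The last lines also show the padding bias: with one padded row out of four, the unmasked mean is 3/4 of the masked mean.

```python
import numpy as np


def masked_moments(x, mask):
    """Per-channel mean and biased variance over the rows where mask is True.

    x: (N, C) float array; mask: (N,) bool array.
    Hypothetical helper for illustration only.
    """
    real = x[mask]                               # keep only the real (unpadded) rows
    mean = real.mean(axis=0)
    var = ((real - mean) ** 2).mean(axis=0)      # divide by the number of real rows
    return mean, var


def ema_update(running, batch_value, momentum):
    """Assumed Keras-style update: running * momentum + batch * (1 - momentum)."""
    return running * momentum + batch_value * (1.0 - momentum)


# Three real cells followed by one zero-padded row.
x = np.array([[2.0, 4.0],
              [4.0, 8.0],
              [6.0, 12.0],
              [0.0, 0.0]])
mask = np.array([True, True, True, False])

masked_mean, masked_var = masked_moments(x, mask)   # mean is [4., 8.]
vanilla_mean = x.mean(axis=0)                       # [3., 6.]: what unmasked BN would see

# One update of a running mean that starts at zero.
moving_mean = ema_update(np.zeros(2), masked_mean, momentum=0.9)  # approximately [0.4, 0.8]
```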
Every row of x is normalized with the chosen statistics regardless of its mask value. The padded rows of the output are therefore well-defined real numbers. Downstream operations that consume the output (the SpConv tf.gather + einsum path) still zero them out via the ground-row gather, so the padded rows of the output do not affect the final result. Keeping the shape constant is what preserves the tf.function cache.
When called with a single tensor instead of a tuple, the layer falls back to standard (un-masked) batch normalization semantics. This is the path exercised by the unit tests when the mask is absent, and by users who want to use the same class without the masking machinery.
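The forward contract can be sketched in NumPy as below. The sketch dispatches between the masked and un-masked paths and between training and inference, and it normalizes every row, padded ones included. This is an illustrative reimplementation, not the layer's code. It assumes biased variance and the moving-average rule moving * momentum + batch * (1 - momentum). The function name masked_bn_forward and its explicit state arguments are hypothetical, since the real layer keeps its statistics in the packed moving_stats weight.

```python
import numpy as np


def masked_bn_forward(inputs, moving_mean, moving_var, gamma, beta,
                      training=False, momentum=0.99, epsilon=1e-3):
    """Hypothetical sketch of the forward contract.

    inputs is either x of shape (N, C) or the pair (x, mask) with mask of
    shape (N,). Returns (y, new_moving_mean, new_moving_var).
    """
    if isinstance(inputs, (tuple, list)):
        x, mask = inputs
    else:
        # No mask: every row counts, i.e. plain batch normalization.
        x = inputs
        mask = np.ones(x.shape[0], dtype=bool)

    if training:
        real = x[mask]
        mean = real.mean(axis=0)
        var = real.var(axis=0)                      # biased variance over real rows
        new_mean = moving_mean * momentum + mean * (1.0 - momentum)
        new_var = moving_var * momentum + var * (1.0 - momentum)
    else:
        # Inference: use the running statistics and leave them unchanged.
        mean, var = moving_mean, moving_var
        new_mean, new_var = moving_mean, moving_var

    # Every row is normalized, so the output keeps shape (N, C) and the
    # padded rows hold finite values.
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return y, new_mean, new_var


x = np.array([[1.0], [3.0], [0.0]])      # last row is padding
mask = np.array([True, True, False])
zeros, ones = np.zeros(1), np.ones(1)

y_masked, mm, mv = masked_bn_forward((x, mask), zeros, ones, ones, zeros, training=True)
y_plain, _, _ = masked_bn_forward(x, zeros, ones, ones, zeros, training=True)
y_infer, _, _ = masked_bn_forward((x, mask), zeros, ones, ones, zeros, training=False)
```

In this sketch the inference branch never looks at mask, so the mask only changes the result when training=True. The un-masked call (y_plain) differs from the masked one because the padded zero enters its mean and variance.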
The implementation is intentionally close to the Keras 3 stock BatchNormalization: the parameter names and get_config schema mirror it, so swapping one for the other is a one-line change. axis is fixed to -1 (channel-last), because the SpConv layer stack always runs on rank-2 (cells, channels) tensors.
- Variables:
axis (int) – The axis to normalize over. Must be -1 for the SpConv use case; declared as a parameter so configs that already populate it round-trip cleanly.
momentum (float) – Exponential-moving-average decay for the running statistics. The stock Keras default is 0.99; SpConv defaults to 0.9 or 0.99 depending on the wrap position.
epsilon (float) – Small float added to the variance for numerical stability.
center (bool) – When True, the layer learns a per-channel beta offset.
scale (bool) – When True, the layer learns a per-channel gamma scale.
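For reference, these are the textbook batch normalization equations that these variables enter, written for channel c. M is the set of rows where mask is True (all N rows when no mask is given), m is momentum, and hats denote running statistics. This is the standard formulation that stock Keras follows, stated here as an assumption rather than a transcription of this layer's source. With center=False the beta term is dropped, and with scale=False gamma is fixed to 1. At inference the batch mean and variance are replaced by the running ones.

```latex
\mu_c = \frac{1}{|M|} \sum_{i \in M} x_{ic}, \qquad
\sigma_c^2 = \frac{1}{|M|} \sum_{i \in M} \left( x_{ic} - \mu_c \right)^2

y_{ic} = \gamma_c \, \frac{x_{ic} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad i = 1, \dots, N

\hat{\mu}_c \leftarrow m \, \hat{\mu}_c + (1 - m) \, \mu_c, \qquad
\hat{\sigma}_c^2 \leftarrow m \, \hat{\sigma}_c^2 + (1 - m) \, \sigma_c^2
```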
- __init__(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', moving_mean_initializer='zeros', moving_variance_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs)
Initialize the member attributes of the layer and the internal weights that do not depend on the input dimensionality.
- Parameters:
kwargs – The keyword specification to parametrize the layer.
- build(input_shape)
Logic to build the layer before the first call is executed.
This method can be overloaded by any derived class to either change or extend the logic.
- Parameters:
input_shape – The shape of the input tensor.
- property moving_mean
First row of the packed moving_stats weight.
The slice tensor moving_stats[0] is returned as a Tensor (not a tf.Variable); bn.moving_mean.assign(new_mean) still works because TF’s variable-slice machinery writes the assignment back to the underlying moving_stats variable. The Tensor-typed return matches the in-graph usage in call() and _masked_stats.
- property moving_variance
Second row of the packed moving_stats weight.
Same assignability property as moving_mean: bn.moving_variance.assign(new_var) works via TF’s variable-slice machinery, despite the Tensor return type.
- call(inputs, training=False)
- Parameters:
inputs – Either x (no mask) or the 2-tuple (x, mask). x is a rank-2 float tensor (cells, channels); mask is a 1-D boolean tensor (cells,).
training – Whether the layer is being called in training mode. When True, statistics are computed from the current batch (with the mask applied) and the running statistics are updated. When False, the running statistics are used.
- Returns:
A tensor of the same shape as x with the (masked) batch normalization applied.
- compute_output_shape(input_shape)
- get_config()
Obtain the dictionary specifying how to serialize the layer.
- Returns:
The dictionary with the necessary data to serialize the layer.
- Return type:
dict