src.model.deeplearn.layer.masked_batch_normalization

Classes

MaskedBatchNormalization(*args, **kwargs)

class src.model.deeplearn.layer.masked_batch_normalization.MaskedBatchNormalization(*args, **kwargs)
Author:

Alberto M. Esmoris Pena

Batch normalization that ignores padded rows when computing the batch statistics. Designed for use with DLSparseConcatSequencer and the SpConv layer stack: every batch is statically padded along its row axis so that tf.function traces only once. The padded rows are zeros everywhere, so a vanilla keras.layers.BatchNormalization would count them in its batch statistics. With a fraction p of padded rows, the batch mean shrinks toward zero by the factor (1 - p) and the batch variance is distorted as well. Both errors then propagate into the running mean and variance.
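To make the padding bias concrete, assume a batch of N rows of which a fraction p are all-zero padding, and let μ and σ² be the per-channel mean and population variance of the real rows. The statistics a vanilla batch normalization would compute over all N rows are then:

```latex
\hat{\mu} = (1-p)\,\mu,
\qquad
\hat{\sigma}^2 = (1-p)\left(\sigma^2 + \mu^2\right) - (1-p)^2\mu^2
               = (1-p)\,\sigma^2 + p(1-p)\,\mu^2 .
```

The mean is shrunk by exactly (1 - p), while the variance is scaled by (1 - p) and inflated by a term proportional to μ². For example, real rows {1, 3} padded with two zero rows give p = 0.5, μ = 2 and σ² = 1, so the vanilla statistics are 1 and 0.5 + 1 = 1.5 instead of 2 and 1.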

Forward contract:

  • Inputs are a 2-tuple (x, mask) where x has shape (N, C) and mask is a 1-D boolean tensor of shape (N,). mask[i] == True marks the i-th row of x as a real cell that should contribute to the batch statistics; mask[i] == False marks the i-th row as padding.

  • During training, the mean and variance are computed only over the rows where mask is True. The running mean and running variance are updated with these masked statistics following the standard exponential-moving-average rule with momentum self.momentum.

  • During inference, the running statistics are used (vanilla BN semantics).

  • Every row of x is normalized with the chosen statistics regardless of its mask value, so the padded rows of the output are well-defined real numbers. Downstream consumers (the SpConv tf.gather + einsum path) zero them out again via the ground-row gather, which makes their values mathematically irrelevant to the final result. Keeping the shape constant is what preserves the tf.function cache.

  • When called with a single tensor instead of a tuple, the layer falls back to standard (un-masked) batch normalization semantics. This is the path exercised by the unit tests when the mask is absent and by users who want to drop into the same class without the masking machinery.
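The forward contract above can be sketched in NumPy as follows. This is a hedged illustration rather than the layer's code: masked_batch_norm is a hypothetical function, and both the EMA convention (new = momentum * old + (1 - momentum) * batch) and the use of the population variance are assumptions borrowed from stock Keras BatchNormalization.

```python
import numpy as np


def masked_batch_norm(x, mask=None, *, training, moving_mean, moving_var,
                      gamma, beta, momentum=0.99, epsilon=1e-3):
    """Hypothetical sketch of the forward contract.

    Returns (y, new_moving_mean, new_moving_var). Assumes Keras-style EMA
    and population variance; assumes at least one real row when training.
    """
    if training:
        # Only rows flagged True contribute; with no mask every row counts.
        real = x if mask is None else x[mask]
        mean = real.mean(axis=0)
        var = real.var(axis=0)  # population variance (ddof=0)
        # Update the running statistics with the masked batch statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    else:
        # Inference: vanilla BN semantics with the running statistics.
        mean, var = moving_mean, moving_var
    # Every row, padded or not, is normalized, so the output keeps shape (N, C).
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return y, moving_mean, moving_var


# Two real rows followed by two padding rows, one channel.
x = np.array([[1.0], [3.0], [0.0], [0.0]])
mask = np.array([True, True, False, False])
y, mm, mv = masked_batch_norm(x, mask, training=True,
                              moving_mean=np.zeros(1), moving_var=np.ones(1),
                              gamma=np.ones(1), beta=np.zeros(1), momentum=0.9)
```

Passing mask=None reproduces the single-tensor fallback. An all-padding batch would leave the masked mean undefined in this sketch; how the real layer handles that case is not documented here.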

The implementation is intentionally close to the stock Keras 3 BatchNormalization: the parameter names and the get_config schema mirror it, so swapping one layer for the other is a one-line change. axis is fixed to -1 (channel-last) because the SpConv layer stack always runs on rank-2 (cells, channels) tensors.

Variables:
  • axis (int) – The axis to normalize over. Must be -1 for the SpConv use case; declared as a parameter so configs that already populate it round-trip cleanly.

  • momentum (float) – Exponential-moving-average decay for the running statistics. Stock Keras default is 0.99; SpConv defaults to 0.9 / 0.99 depending on the wrap position.

  • epsilon (float) – Small float added to the variance for numerical stability.

  • center (bool) – When True the layer learns a per-channel beta offset.

  • scale (bool) – When True the layer learns a per-channel gamma scale.
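For reference, this is how the variables above would enter the computation if the layer follows the stock Keras formulas (an assumption; only the parameter names are documented here). For channel c, let μ_c and σ_c² be the masked batch statistics during training or the running statistics during inference, and let m_c and v_c be the running mean and variance:

```latex
y_{i,c} = \gamma_c \, \frac{x_{i,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad
m_c \leftarrow \text{momentum}\cdot m_c + (1 - \text{momentum})\,\mu_c,
\qquad
v_c \leftarrow \text{momentum}\cdot v_c + (1 - \text{momentum})\,\sigma_c^2 .
```

Under this reading, γ_c is held at 1 when scale is False and β_c at 0 when center is False, and the two running-statistic updates happen only in training mode.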

__init__(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', moving_mean_initializer='zeros', moving_variance_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs)

Initialize the member attributes of the layer and the internal weights that do not depend on the input dimensionality.

Parameters:

kwargs – The keyword specification to parametrize the layer.

build(input_shape)

Logic to build the layer before the first call is executed.

This method can be overridden by any derived class to either change or extend the logic.

Parameters:

input_shape – The shape of the layer input: the shape of x, or the pair of shapes of (x, mask) when a mask is given.

property moving_mean

First row of the packed moving_stats weight.

The slice tensor moving_stats[0] is returned as a Tensor (not a tf.Variable); bn.moving_mean.assign(new_mean) still works because TF’s variable-slice machinery writes the assignment back to the underlying moving_stats variable. The Tensor-typed return matches the in-graph usage in call() and _masked_stats.

property moving_variance

Second row of the packed moving_stats weight.

Same assignability property as moving_mean: bn.moving_variance.assign(new_var) works via TF’s variable-slice machinery, despite the Tensor return type.
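The write-back behaviour described for both properties can be checked on a plain tf.Variable. This sketch assumes TensorFlow 2 in eager mode and uses a standalone moving_stats variable as a hypothetical stand-in for the layer's packed weight (row 0 running mean, row 1 running variance); it does not instantiate the layer itself.

```python
import tensorflow as tf

# Hypothetical stand-in for the packed weight: row 0 holds the running mean,
# row 1 the running variance, one column per channel.
moving_stats = tf.Variable([[0.0, 0.0, 0.0],
                            [1.0, 1.0, 1.0]])

# Slicing a variable yields a plain Tensor, not a tf.Variable.
mean_row = moving_stats[0]
print(isinstance(mean_row, tf.Variable))  # False

# Calling assign() on a slice writes into the underlying variable.
moving_stats[0].assign([0.5, 0.5, 0.5])
moving_stats[1].assign([2.0, 2.0, 2.0])

print(moving_stats.numpy())
```

To observe the updated values, read the slice again after the assignment rather than reusing a tensor obtained before it.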

call(inputs, training=False)
Parameters:
  • inputs – Either x (no mask) or the 2-tuple (x, mask). x is a rank-2 float tensor (cells, channels); mask is a 1-D boolean tensor (cells,).

  • training – Whether the layer is being called in training mode. When True, statistics are computed from the current batch (with the mask applied) and the running statistics are updated. When False, the running statistics are used.

Returns:

A tensor of the same shape as x with the (masked) batch normalization applied.
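One way to build the (x, mask) pair from a batch with a variable number of real rows is to zero-pad it to a fixed row count, in the spirit of the static padding described for this layer. pad_to_static_rows and n_rows are hypothetical names, and placing the padding after the real rows is an assumption; the actual padding is performed by DLSparseConcatSequencer, whose API is not shown here.

```python
import numpy as np


def pad_to_static_rows(x_real, n_rows):
    """Hypothetical helper: zero-pad x_real (M, C) to (n_rows, C) and build the mask."""
    m, c = x_real.shape
    if m > n_rows:
        raise ValueError(f"batch has {m} rows but the static size is {n_rows}")
    x = np.zeros((n_rows, c), dtype=x_real.dtype)
    x[:m] = x_real                 # real cells first, all-zero padding after
    mask = np.zeros(n_rows, dtype=bool)
    mask[:m] = True                # True marks a real cell, False marks padding
    return x, mask


x, mask = pad_to_static_rows(np.ones((3, 2), dtype=np.float32), 5)
```

Because n_rows stays the same for every batch, the shapes seen by the layer never change, which is what keeps a single tf.function trace valid.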

compute_output_shape(input_shape)

get_config()

Obtain the dictionary specifying how to serialize the layer.

Returns:

The dictionary with the necessary data to serialize the layer.

Return type:

dict

classmethod from_config(config)

Deserialize a layer from given specification.

Parameters:

config – The dictionary specifying how to deserialize the layer.

Returns:

The deserialized layer.

Return type:

Layer or derived