src.tests.masked_batch_normalization_test
Classes
- class src.tests.masked_batch_normalization_test.MaskedBatchNormalizationTest
- Author:
Alberto M. Esmoris Pena
Exhaustive unit test for MaskedBatchNormalization. The layer corrects the bias that padded zero rows introduce into the batch-normalization statistics when DLSparseConcatSequencer pads its output to a static shape. These tests verify the correctness of the masked-statistics path independently of the SpConv layer stack. A future regression in the masking math is therefore caught here, before it can confuse a full-pipeline run. The subtests cover:
Vanilla BN parity: when called without a mask, the layer must reproduce the outputs of the stock keras.layers.BatchNormalization to within floating-point tolerance.
All-True mask parity: supplying an all-True mask must produce the same output as the vanilla path (the mask is a no-op).
Half mask: when half of the rows are padded, the computed batch mean and variance must match the ground-truth mean and variance of the real-cell sub-tensor. The normalized real-cell positions must equal what they would have received as a stand-alone (unpadded) batch.
Single real cell: when exactly one row is unmasked, the batch mean is that row and the batch variance is zero. The normalized value for that row is then 0 / sqrt(eps) = 0, so only the beta offset remains.
Heterogeneous real distribution: real rows are drawn from N(mu, sigma^2) and mixed with many padded zero rows. The layer must recover mu and sigma^2 (within the EMA momentum) and leave the running stats in a state where eval-mode normalization produces clean activations.
Padded rows still normalized: every output row, including padded ones, must be finite. Padded rows receive the same affine transform as every other row; the downstream gather-based layers are what zero them out.
Running stats EMA: after N training calls with constant data, the running stats must converge to the masked-batch statistics at the documented exponential rate.
Eval mode uses running stats only: calling with training=False must use the moving stats and must not update them.
Empty mask: when every row is padded, the call must not crash, must not pollute the running stats, and must return a tensor of the right shape (defensive fallback path).
Numerical stability, huge values: feeding rows with very large magnitudes must not overflow the variance computation.
Numerical stability, tiny values: feeding rows that are near zero everywhere must not underflow or divide by zero (the epsilon term guards against this).
Center / scale toggles: disabling center or scale must remove the corresponding trainable weight, and the corresponding affine term must not be applied.
Serialization round-trip: get_config / from_config must reproduce all the configurable fields.
Gradient flow on real cells: the gradient w.r.t. gamma / beta must be non-zero when real cells are present. Padded cells must contribute no gradient to gamma / beta, because they were not used to compute the stats.
Bool vs int32 mask: both dtypes are accepted and produce identical outputs.
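The masked-statistics idea behind these subtests can be sketched in NumPy. This is a minimal sketch, not the layer's TensorFlow code. It assumes a 2-D (rows, channels) input, a per-row mask in which True marks a real row, and population (biased) variance; masked_moments is a hypothetical helper name. The example shows the bias that zero padding introduces into ordinary moments. It also covers the all-True parity case.

```python
import numpy as np

def masked_moments(x, mask):
    """Per-channel mean and population variance over the rows where mask is True.

    Hypothetical helper: x has shape (rows, channels), mask has shape (rows,).
    """
    w = mask.astype(x.dtype)[:, None]          # 1.0 for real rows, 0.0 for padding
    count = w.sum()                            # number of real rows
    mean = (w * x).sum(axis=0) / count
    var = (w * (x - mean) ** 2).sum(axis=0) / count
    return mean, var

rng = np.random.default_rng(0)
real = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
padded = np.concatenate([real, np.zeros((8, 4))])   # static-shape padding with zero rows
mask = np.array([True] * 8 + [False] * 8)

m_masked, v_masked = masked_moments(padded, mask)
m_plain, v_plain = padded.mean(axis=0), padded.var(axis=0)

# The masked stats equal the stats of the real rows alone...
print(np.allclose(m_masked, real.mean(axis=0)), np.allclose(v_masked, real.var(axis=0)))  # True True
# ...while the plain stats are biased: with half the rows padded, the mean is halved.
print(np.allclose(m_plain, 0.5 * real.mean(axis=0)))  # True
# An all-True mask reduces to the ordinary (unmasked) moments.
m_all, v_all = masked_moments(real, np.ones(8, dtype=bool))
print(np.allclose(m_all, real.mean(axis=0)), np.allclose(v_all, real.var(axis=0)))  # True True
```

Weighting by a 0/1 mask keeps the array shape static, which mirrors why the layer cannot simply drop the padded rows.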
- __init__()
Basic configuration for any VL3D test.
- Parameters:
name (str) – Test name
- run()
Run the test.
- Returns:
True if test is successfully passed, False otherwise.
- Return type:
bool
- subtest_vanilla_bn_parity()
Without a mask the layer matches the stock Keras BatchNormalization. The tolerance keeps a generous margin even though both implementations use the population variance estimator, so the outputs should match bit-exactly or very closely.
- subtest_all_true_mask_parity()
Mask of all True == no mask: outputs must be identical.
- subtest_half_mask()
Half the rows are padding. The masked stats must equal the stats over the real half; the real rows’ outputs must equal what they would have received standalone.
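The half-mask property can be checked with a NumPy sketch under stated assumptions. masked_bn and plain_bn are hypothetical helpers rather than the layer's API, and epsilon is set to 1e-3, the Keras BatchNormalization default. The real rows normalized with the masked stats should coincide with a plain BN applied to the real rows alone. The padded rows should still come out finite.

```python
import numpy as np

EPS = 1e-3  # assumed epsilon (Keras default); the layer's value may differ

def masked_bn(x, mask, gamma, beta, eps=EPS):
    """Hypothetical masked BN forward pass (training mode) on a (rows, channels) array."""
    w = mask.astype(x.dtype)[:, None]
    count = w.sum()
    mean = (w * x).sum(axis=0) / count
    var = (w * (x - mean) ** 2).sum(axis=0) / count
    # Every row, padded or not, is normalized with the masked stats.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def plain_bn(x, gamma, beta, eps=EPS):
    """Ordinary BN forward pass with population variance."""
    return gamma * (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps) + beta

rng = np.random.default_rng(1)
real = rng.normal(size=(6, 3))
x = np.concatenate([real, np.zeros((6, 3))])
mask = np.array([True] * 6 + [False] * 6)
gamma, beta = np.full(3, 1.5), np.full(3, -0.25)

y = masked_bn(x, mask, gamma, beta)
y_alone = plain_bn(real, gamma, beta)
print(np.allclose(y[:6], y_alone))   # True: real rows match the stand-alone batch
print(np.isfinite(y).all())          # True: padded rows are normalized too
```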
- subtest_single_real_cell()
With one real row the batch mean is that row and the batch variance is exactly zero.
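With a single real row, the masked mean is that row and every squared deviation is zero. The normalized value is therefore 0 / sqrt(eps), and the affine output reduces to beta. This is a NumPy sketch that assumes epsilon = 1e-3 and illustrative gamma/beta values.

```python
import numpy as np

eps = 1e-3   # assumed epsilon (Keras default)
x = np.array([[2.0, -1.0],
              [0.0, 0.0],
              [0.0, 0.0]])            # one real row followed by two padded rows
mask = np.array([True, False, False])
gamma, beta = np.array([0.7, 2.0]), np.array([0.1, -0.3])

w = mask.astype(x.dtype)[:, None]
count = w.sum()
mean = (w * x).sum(axis=0) / count                # equals the real row x[0]
var = (w * (x - mean) ** 2).sum(axis=0) / count   # exactly zero
y = gamma * (x - mean) / np.sqrt(var + eps) + beta

print(np.array_equal(mean, x[0]), np.all(var == 0), np.allclose(y[0], beta))  # True True True
```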
- subtest_heterogeneous_real_distribution()
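Real rows drawn from N(mu, sigma^2) plus many padded zeros. The layer must recover mu and sigma^2 in the running stats (with momentum=0 they are set in a single call). Without masking, the running mean would be pulled toward zero by the fraction of padded rows, and the running variance would be distorted as well. For a real-row fraction p, the unmasked variance is p*sigma^2 + p*(1 - p)*mu^2 instead of sigma^2.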
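The following NumPy sketch illustrates the momentum=0 case and the bias that masking removes. It assumes the Keras moving-average convention, new = momentum * old + (1 - momentum) * batch, which is consistent with momentum=0 setting the stats in a single call. update_moving is a hypothetical helper. For a real-row fraction p, unmasked moments over the zero-padded batch give a mean of p*mu and a variance of p*sigma^2 + p*(1 - p)*mu^2. Here that is about 2.61 instead of 0.25.

```python
import numpy as np

def update_moving(moving, batch, momentum):
    # Assumed Keras convention: new = momentum * old + (1 - momentum) * batch.
    return momentum * moving + (1.0 - momentum) * batch

mu, sigma = 4.0, 0.5
rng = np.random.default_rng(2)
real = rng.normal(mu, sigma, size=(200, 1))
x = np.concatenate([real, np.zeros((800, 1))])                 # 80 % padded rows
mask = np.concatenate([np.ones(200, dtype=bool), np.zeros(800, dtype=bool)])

w = mask.astype(x.dtype)[:, None]
mean = (w * x).sum(axis=0) / w.sum()
var = (w * (x - mean) ** 2).sum(axis=0) / w.sum()

# momentum=0: a single call overwrites the moving stats with the masked batch stats.
moving_mean = update_moving(np.zeros(1), mean, momentum=0.0)
moving_var = update_moving(np.ones(1), var, momentum=0.0)
print(moving_mean, moving_var)        # close to mu = 4.0 and sigma**2 = 0.25

# Unmasked moments of the padded batch, with real-row fraction p = 0.2.
p = 0.2
print(x.mean(axis=0), x.var(axis=0))  # mean ~ p * mu, variance ~ p*sigma^2 + p*(1-p)*mu^2
```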
- subtest_padded_rows_finite()
Even though padded rows are masked out of the statistics, the layer normalizes every row with the masked stats so the output is dense. All output rows must be finite real numbers (no NaN, no Inf).
- subtest_running_stats_ema()
With non-zero momentum the running stats converge to the masked stats following the documented EMA rule.
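This sketch assumes the Keras convention new = momentum * old + (1 - momentum) * batch. Under that convention, the gap to a constant batch statistic t shrinks by a factor of momentum on each call, so after N calls moving_N = t + momentum^N * (moving_0 - t). The code checks that closed form numerically. The momentum, target and initial value are illustrative assumptions, not the test's actual parameters.

```python
import numpy as np

momentum = 0.9     # illustrative value
target = 2.5       # masked batch statistic, identical on every call (constant data)
moving0 = 0.0      # illustrative initial moving value

moving = moving0
for n in range(1, 21):
    moving = momentum * moving + (1.0 - momentum) * target
    # Closed form after n calls: the initial gap shrinks by a factor momentum**n.
    assert np.isclose(moving, target + momentum ** n * (moving0 - target))

print(moving)   # approaches 2.5; the remaining gap is 2.5 * 0.9**20
```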
- subtest_eval_uses_running()
Inference (training=False) must not update the running stats and must use them in the normalization.
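The train/eval split this subtest checks can be sketched minimally in NumPy. TinyMaskedBN is a hypothetical stand-in, not the real layer. It updates the moving stats only when training=True. During inference it normalizes with those stats and leaves them unmodified. The affine gamma/beta step is omitted for brevity. Momentum 0.99 and epsilon 1e-3 are the Keras BatchNormalization defaults, assumed here.

```python
import numpy as np

class TinyMaskedBN:
    """Hypothetical stand-in for the layer's train/eval logic (not the real class)."""

    def __init__(self, channels, momentum=0.99, eps=1e-3):
        self.moving_mean = np.zeros(channels)
        self.moving_var = np.ones(channels)
        self.momentum, self.eps = momentum, eps

    def __call__(self, x, mask, training):
        if training:
            w = mask.astype(x.dtype)[:, None]
            count = w.sum()
            mean = (w * x).sum(axis=0) / count
            var = (w * (x - mean) ** 2).sum(axis=0) / count
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference: use the moving stats and leave them untouched.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.eps)

layer = TinyMaskedBN(channels=2)
x = np.array([[1.0, 2.0], [3.0, 6.0], [0.0, 0.0]])
mask = np.array([True, True, False])
layer(x, mask, training=True)                      # one training call updates the stats
before = (layer.moving_mean.copy(), layer.moving_var.copy())
y = layer(x, mask, training=False)                 # eval call must not change them
```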
- subtest_empty_mask()
All rows padded. The layer must (a) not crash, (b) keep the running stats unchanged (defensive fallback), (c) return a finite tensor of the right shape.
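Without a guard, an all-padded batch gives count = 0, and the masked moments become 0/0 = NaN. The NumPy sketch below shows one way a defensive fallback could look, under stated assumptions. The hypothetical masked_train_step detects the empty mask, normalizes with the moving stats, and returns them unchanged. The actual layer's fallback may differ. Inside a TensorFlow graph, the Python if would also need a tensor-level equivalent such as tf.cond.

```python
import numpy as np

def masked_train_step(x, mask, moving_mean, moving_var, momentum=0.99, eps=1e-3):
    """Hypothetical training step with an empty-mask guard.

    Returns (output, new_moving_mean, new_moving_var).
    """
    w = mask.astype(x.dtype)[:, None]
    count = w.sum()
    if count == 0:
        # Assumed fallback: no real rows, so normalize with the moving stats
        # and leave them unchanged instead of computing 0/0.
        y = (x - moving_mean) / np.sqrt(moving_var + eps)
        return y, moving_mean, moving_var
    mean = (w * x).sum(axis=0) / count
    var = (w * (x - mean) ** 2).sum(axis=0) / count
    new_mean = momentum * moving_mean + (1 - momentum) * mean
    new_var = momentum * moving_var + (1 - momentum) * var
    return (x - mean) / np.sqrt(var + eps), new_mean, new_var

x = np.zeros((4, 3))
mask = np.zeros(4, dtype=bool)                     # every row is padding
mm, mv = np.full(3, 0.5), np.full(3, 2.0)
y, mm2, mv2 = masked_train_step(x, mask, mm, mv)
```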
- subtest_numerical_huge()
Large magnitudes must not overflow.
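The NumPy sketch below, in float32, shows why centering matters for large magnitudes. The one-pass formula E[x^2] - E[x]^2 squares the raw values and overflows. A two-pass computation squares only the deviations from the mean. This is one standard way to keep the variance finite. Whether the layer uses exactly this strategy, or for example upcasting, is an assumption here.

```python
import numpy as np

x = np.array([1.0e20, 1.1e20, 0.9e20], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    # One-pass formula: x**2 ~ 1e40 exceeds the float32 maximum (~3.4e38),
    # so E[x^2] and E[x]^2 both become inf and their difference is NaN.
    naive = np.mean(x ** 2) - np.mean(x) ** 2

# Two-pass formula: center first; deviations ~1e19 square to ~1e38, still finite.
mean = np.mean(x)
two_pass = np.mean((x - mean) ** 2)

print(naive, two_pass)
```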
- subtest_numerical_tiny()
Near-zero inputs must not underflow or divide by zero. The epsilon term covers this.
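A small NumPy illustration of the epsilon guard follows, assuming epsilon = 1e-3, the Keras default. When every row holds the same tiny value, the variance is exactly 0. Dividing by sqrt(var) then yields NaN, while dividing by sqrt(var + eps) keeps the output finite. The value 2**-100 was chosen because it is exactly representable in float32, so the mean and deviations are exact.

```python
import numpy as np

eps = 1e-3                                         # assumed epsilon (Keras default)
x = np.full((4, 2), 2.0 ** -100, dtype=np.float32) # near-zero, identical rows
mean = x.mean(axis=0)
var = ((x - mean) ** 2).mean(axis=0)               # exactly 0: all rows are identical

with np.errstate(divide="ignore", invalid="ignore"):
    unguarded = (x - mean) / np.sqrt(var)          # 0 / 0 -> NaN
guarded = (x - mean) / np.sqrt(var + eps)          # 0 / sqrt(eps) -> 0
```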
- subtest_center_scale_toggles()
Disabling center / scale removes the corresponding trainable weight and skips the affine application.
- subtest_serialization_roundtrip()
get_config / from_config must round-trip every configurable knob.
- subtest_gradient_flow()
Gradient w.r.t. gamma / beta must be non-zero on a training forward pass with real rows present. Padded rows must not contribute to gamma / beta gradients (because they were excluded from the masked stats computation).
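Since y = gamma * x_hat + beta, the chain rule gives dL/dgamma = sum_i g_i * x_hat_i and dL/dbeta = sum_i g_i, where g_i = dL/dy_i is the upstream gradient for row i. Padded rows therefore add nothing to the gamma/beta gradients exactly when their upstream gradient is zero. That happens, for example, when a downstream gather drops them. The NumPy sketch below assumes such a loss, in which padded rows are multiplied by the mask. All names and values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
eps = 1e-3                               # assumed epsilon
x = np.concatenate([rng.normal(size=(5, 2)), np.zeros((3, 2))])
mask = np.array([True] * 5 + [False] * 3)
c = rng.normal(size=x.shape)             # arbitrary per-element loss weights

w = mask.astype(x.dtype)[:, None]
mean = (w * x).sum(axis=0) / w.sum()
var = (w * (x - mean) ** 2).sum(axis=0) / w.sum()
x_hat = (x - mean) / np.sqrt(var + eps)  # every row is normalized, padded ones included

def loss(gamma, beta):
    y = gamma * x_hat + beta
    return np.sum(w * c * y)             # padded rows dropped, as by a downstream gather

g = w * c                                # dL/dy: zero on padded rows
grad_gamma = (g * x_hat).sum(axis=0)
grad_beta = g.sum(axis=0)
print(grad_gamma, grad_beta)
```

A central finite-difference check against loss() confirms the analytic expressions. Because the loss is linear in gamma and beta, the check is exact up to rounding.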
- subtest_bool_vs_int_mask()
The mask must accept bool or int32; results must be identical.
- subtest_fp16_input_mixed_precision()
Mixed-precision contract: when the input arrives as float16 (the upstream layers ran under mixed_float16), the layer must:
1. Compute moments in float32 (so the variance reduction does not overflow or lose precision).
2. Update the moving_stats weight in float32 (so the running stats stay precise across many batches).
3. Return an output tensor in the input’s dtype (float16) so downstream ops see the expected dtype.
Production-ready mixed-precision support hinges on this contract. Without it, the variance reduction silently collapses to 0 under mixed_float16, and downstream normalization produces NaN or wildly miscalibrated outputs.
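The float16 contract can be sketched in NumPy under stated assumptions: illustrative values, epsilon 1e-3, momentum 0.99, and NumPy standing in for the layer's TensorFlow code. Squaring inputs of a few hundred already exceeds the float16 maximum of 65504, so moments computed in float16 overflow. Upcasting to float32 keeps the moments and the moving statistic finite, and the output is then cast back to the input dtype.

```python
import numpy as np

eps = 1e-3   # assumed epsilon (Keras default)
x16 = np.array([[300.0], [310.0], [290.0], [305.0]], dtype=np.float16)

# Moments in float16: even the smallest square, 290**2 = 84100, exceeds the float16 maximum.
with np.errstate(over="ignore"):
    e_x2_fp16 = (x16 * x16).mean(axis=0)

# Contract step 1: compute moments in float32.
x32 = x16.astype(np.float32)
mean = x32.mean(axis=0)
var = ((x32 - mean) ** 2).mean(axis=0)

# Contract step 2: keep the moving statistic in float32 (momentum 0.99 assumed).
moving_var = np.ones(1, dtype=np.float32)
moving_var = 0.99 * moving_var + 0.01 * var

# Contract step 3: return the output in the input's dtype.
y = ((x32 - mean) / np.sqrt(var + eps)).astype(x16.dtype)
print(e_x2_fp16, var, y.dtype)
```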
- subtest_bf16_input_mixed_precision()
bfloat16 mixed-precision coverage.
bf16 has fp32’s exponent range, so overflow is no more of a risk than in fp32. However, it has only a 7-bit mantissa, so loss of precision in (x - mean)**2 is the binding concern. The subtest mirrors the fp16 subtest’s contract:
1. Moments are computed in fp32.
2. moving_stats stays fp32 with finite values.
3. The output dtype matches the input (bf16).
The subtest is gated by a try/except around the policy/dtype construction, so it skips itself automatically when bf16 is unavailable in the local TF build.