src.tests.torf_gva_layer_test
Classes
- class src.tests.torf_gva_layer_test.TORFGVALayerTest
- Author:
Alberto M. Esmoris Pena
Correctness test for TORFGVALayer. Validates that the mega-fused einsum 'bkg,bkge,bkge->bge' produces outputs that match the legacy Reshape + Multiply + Multiply + ReduceSum tail of the Grouped Vector Attention block on the benchmark shape configurations (small and large). Also verifies the get_config/from_config serialization round-trip, checks that a functional keras.Model wrapping the layer builds without errors, and exercises the shape-validation path in build.
- __init__()
Basic configuration for any VL3D test.
- Parameters:
name (str) – Test name
- run()
Run the full correctness and serialization test suite.
- Returns:
True if every configuration passes, False on the first failure.
- Return type:
bool
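The "False on the first failure" contract of run() amounts to short-circuit evaluation. Below is a minimal sketch, assuming each sub-check is a zero-argument callable that returns a bool. run_all and make_check are hypothetical names, not part of VL3D.

```python
from typing import Callable, Iterable


def run_all(checks: Iterable[Callable[[], bool]]) -> bool:
    """Run checks in order and stop at the first one that returns False.

    Hypothetical stand-in for the control flow of run(): all() short-circuits
    on a generator, so checks after the first failure are never executed.
    """
    return all(check() for check in checks)


if __name__ == "__main__":
    calls = []

    def make_check(tag: str, result: bool) -> Callable[[], bool]:
        # Hypothetical helper: records that the check ran, then returns a fixed result.
        def check() -> bool:
            calls.append(tag)
            return result
        return check

    ok = run_all([
        make_check("small", True),
        make_check("large", False),
        make_check("serialization", True),
    ])
    print(ok, calls)  # False ['small', 'large']
```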
- run_case(name, b, k, g, dg)
Compare the fused TORFGVALayer output against the legacy four-op baseline for a single shape configuration.
- Parameters:
name – Human-readable shape tag (small/large).
b – Batch size.
k – Neighbor count.
g – Group count.
dg – Per-group sub-channel dimensionality.
- Returns:
True on numerical agreement within self.eps.
- Return type:
bool
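One way to express "numerical agreement within self.eps" is a bound on the largest element-wise difference. This is a sketch under assumptions. It takes the criterion to be the maximum absolute difference, which the description does not state. The 1e-5 default is a placeholder, not the test's actual tolerance.

```python
import numpy as np


def agrees(fused: np.ndarray, baseline: np.ndarray, eps: float = 1e-5) -> bool:
    """Return True when both tensors share a shape and differ by at most eps.

    Assumption: agreement is measured as the max absolute difference; the
    real test may use a different metric or tolerance.
    """
    if fused.shape != baseline.shape:
        return False
    if fused.size == 0:
        return True  # nothing to compare
    return bool(np.max(np.abs(fused - baseline)) <= eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.standard_normal((2, 3, 4))
    print(agrees(a, a + 1e-7))  # True
    print(agrees(a, a + 1e-3))  # False
```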
- baseline_output(attn, w, v_mod)
Compute the legacy Reshape + Multiply + Multiply + ReduceSum tail output.
- Parameters:
attn – Attention weights tensor (B, K, G).
w – Weight-MLP output tensor (B, K, G, d_g).
v_mod – Modulated values tensor (B, K, G, d_g).
- Returns:
Aggregated context (B, G, d_g).
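The legacy tail can be sketched in NumPy rather than the Keras ops the test uses. The sketch assumes two things implied by the listed shapes but not spelled out. First, the Reshape appends a singleton axis to attn so that it broadcasts over d_g. Second, the ReduceSum runs over the neighbor axis K. The order of the two multiplies is also an assumption.

```python
import numpy as np


def baseline_output(attn: np.ndarray, w: np.ndarray, v_mod: np.ndarray) -> np.ndarray:
    """Legacy four-op tail: Reshape + Multiply + Multiply + ReduceSum.

    attn: (B, K, G), w: (B, K, G, d_g), v_mod: (B, K, G, d_g) -> (B, G, d_g).
    """
    b, k, g = attn.shape
    attn_4d = np.reshape(attn, (b, k, g, 1))  # Reshape: trailing axis broadcasts over d_g
    weighted = attn_4d * w                    # Multiply: scale weight-MLP output by attention
    context = weighted * v_mod                # Multiply: apply to the modulated values
    return np.sum(context, axis=1)            # ReduceSum over the neighbor axis K


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    B, K, G, DG = 2, 16, 4, 8  # illustrative shape, not the test's small/large configs
    out = baseline_output(rng.random((B, K, G)), rng.random((B, K, G, DG)), rng.random((B, K, G, DG)))
    print(out.shape)  # (2, 4, 8)
```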
- fused_output(attn, w, v_mod)
Compute the TORFGVALayer output on CPU for a deterministic comparison with the baseline.
- Parameters:
attn – Attention weights tensor (B, K, G).
w – Weight-MLP output tensor (B, K, G, d_g).
v_mod – Modulated values tensor (B, K, G, d_g).
- Returns:
Aggregated context (B, G, d_g).
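The fused path folds all four ops into a single contraction. The equation string 'bkg,bkge,bkge->bge' comes from the description above, while the code uses NumPy's einsum instead of the layer's own backend. Here 'e' indexes d_g, and dropping 'k' from the output sums over neighbors. The sketch checks the result against a broadcast reference in which attn[..., None] plays the role of the legacy Reshape.

```python
import numpy as np


def fused_output(attn: np.ndarray, w: np.ndarray, v_mod: np.ndarray) -> np.ndarray:
    """Single-einsum equivalent of the four-op tail: (B,K,G),(B,K,G,d_g)x2 -> (B,G,d_g)."""
    # 'k' is absent from the output, so einsum sums over the neighbor axis.
    return np.einsum("bkg,bkge,bkge->bge", attn, w, v_mod)


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    b, k, g, dg = 2, 16, 4, 8  # illustrative shape only
    attn = rng.standard_normal((b, k, g))
    w = rng.standard_normal((b, k, g, dg))
    v_mod = rng.standard_normal((b, k, g, dg))
    reference = np.sum(attn[..., None] * w * v_mod, axis=1)  # broadcast form of the legacy tail
    fused = fused_output(attn, w, v_mod)
    print(fused.shape, np.allclose(fused, reference))  # (2, 4, 8) True
```

The two results agree only up to floating-point rounding, because einsum may accumulate the products in a different order. That is why the comparison uses a tolerance rather than exact equality.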
- run_serialization()
Check that get_config followed by from_config yields a layer whose outputs are bitwise-identical to the original on a shared input.
- Returns:
True on identical outputs.
- Return type:
bool
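The round-trip contract can be illustrated with a pure-NumPy toy. ToyFusedLayer is hypothetical, and it assumes the layer has no trainable weights. Keras's from_config rebuilds a layer from its configuration only, so a layer with weights would also need them copied (for example with get_weights/set_weights) before a bitwise comparison could succeed.

```python
import numpy as np


class ToyFusedLayer:
    """Hypothetical stand-in mimicking the Keras get_config/from_config contract."""

    def __init__(self, equation: str = "bkg,bkge,bkge->bge", name: str = "toy_fused"):
        self.equation = equation
        self.name = name

    def get_config(self) -> dict:
        # Everything needed to rebuild this toy; it has no weights to save.
        return {"equation": self.equation, "name": self.name}

    @classmethod
    def from_config(cls, config: dict) -> "ToyFusedLayer":
        return cls(**config)

    def __call__(self, attn, w, v_mod):
        return np.einsum(self.equation, attn, w, v_mod)


def round_trip_identical(layer, inputs) -> bool:
    """Rebuild via get_config/from_config and require bitwise-identical outputs."""
    clone = type(layer).from_config(layer.get_config())
    return bool(np.array_equal(layer(*inputs), clone(*inputs)))
```

np.array_equal demands exact equality, which matches the "bitwise-identical" requirement. It holds here because both instances run the same deterministic CPU computation on the same inputs.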
- run_graph_build()
Build a small functional keras.Model wrapping the layer and verify that a single forward pass succeeds.
- Returns:
True when the model builds and produces an output tensor with the expected shape.
- Return type:
bool
- run_shape_validation()
Confirm that invalid input shapes passed to TORFGVALayer.build() raise src.model.deeplearn.deep_learning_exception.DeepLearningException.
- Returns:
True when every invalid-shape combination raises.
- Return type:
bool
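The kind of check this method exercises can be sketched as follows. The validation rules are assumptions inferred from the documented shapes, not the actual logic of TORFGVALayer.build(). Those assumed rules are: attn is rank 3, w and v_mod are rank 4 and identical, and K and G match across inputs. DeepLearningException here is a local stand-in class, and validate_shapes and every_invalid_case_raises are hypothetical helpers.

```python
class DeepLearningException(Exception):
    """Local stand-in for src.model.deeplearn.deep_learning_exception.DeepLearningException."""


def validate_shapes(attn_shape, w_shape, v_shape) -> None:
    """Hypothetical build()-time checks; the real rules may differ."""
    if len(attn_shape) != 3:
        raise DeepLearningException(f"attn must be rank 3 (B, K, G), got {attn_shape}")
    if len(w_shape) != 4 or len(v_shape) != 4:
        raise DeepLearningException("w and v_mod must be rank 4 (B, K, G, d_g)")
    if tuple(w_shape) != tuple(v_shape):
        raise DeepLearningException(f"w {w_shape} and v_mod {v_shape} must match")
    # Skip the batch axis (often None at build time) and compare K and G.
    if tuple(attn_shape[1:]) != tuple(w_shape[1:3]):
        raise DeepLearningException(f"K/G mismatch between attn {attn_shape} and w {w_shape}")


def every_invalid_case_raises(cases) -> bool:
    """Mirror of the run_shape_validation contract: True only if every case raises."""
    for shapes in cases:
        try:
            validate_shapes(*shapes)
        except DeepLearningException:
            continue
        return False  # an invalid combination slipped through
    return True
```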