src.tests.torf_gva_layer_test

Classes

TORFGVALayerTest()

class src.tests.torf_gva_layer_test.TORFGVALayerTest
Author:

Alberto M. Esmoris Pena

Correctness test for TORFGVALayer. Validates that the mega-fused einsum 'bkg,bkge,bkge->bge' produces outputs matching those of the legacy Reshape + Multiply + Multiply + ReduceSum tail of the Grouped Vector Attention block on both benchmark shape configurations (small and large). It also verifies the get_config / from_config serialization round-trip, checks that a functional keras.Model wrapping the layer builds without errors, and exercises the shape-validation path in build().
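Written out index by index, the einsum subscripts 'bkg,bkge,bkge->bge' describe the following contraction. Here b is the batch index, k the neighbor index (summed away), g the group index, and e the per-group sub-channel index of size d_g; the symbols follow the parameter names used by baseline_output and fused_output.

```latex
\mathrm{out}_{b,g,e} \;=\; \sum_{k=1}^{K} \mathrm{attn}_{b,k,g}\; w_{b,k,g,e}\; v^{\mathrm{mod}}_{b,k,g,e},
\qquad 1 \le b \le B,\quad 1 \le g \le G,\quad 1 \le e \le d_g
```

The legacy tail reaches the same sum in four steps: reshape attn (presumably to (B, K, G, 1)) so it broadcasts over e, multiply by w, multiply by v_mod, then reduce-sum over k. In exact arithmetic the two paths are identical, so any float32 disagreement comes from rounding and evaluation order, which is why the per-shape comparison uses a tolerance rather than exact equality.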

__init__()

Basic configuration for any VL3D test.

Parameters:

name (str) – Test name

run()

Run the full correctness and serialization test suite.

Returns:

True if every configuration passes, False on the first failure.

Return type:

bool
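A minimal sketch of the "False on the first failure" contract, assuming run() simply chains its sub-checks in order. The names run_suite and make_check are hypothetical and not part of the VL3D API; the point is that all() over a generator short-circuits, so later checks are skipped once one fails.

```python
from typing import Callable, Iterable, List


def run_suite(checks: Iterable[Callable[[], bool]]) -> bool:
    """Return True only if every check passes, stopping at the first failure."""
    # all() consumes the generator lazily, so checks after a failing one never run.
    return all(check() for check in checks)


calls: List[str] = []


def make_check(tag: str, result: bool) -> Callable[[], bool]:
    """Hypothetical stand-in for run_case(...), run_serialization(), etc."""
    def check() -> bool:
        calls.append(tag)  # record which checks actually executed
        return result
    return check


ok = run_suite([
    make_check("small", True),
    make_check("large", False),
    make_check("serialization", True),
])
print(ok, calls)  # False ['small', 'large']
```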

run_case(name, b, k, g, dg)

Compare the fused TORFGVALayer output against the legacy four-op baseline for a single shape configuration.

Parameters:
  • name – Human-readable shape tag (small/large).

  • b – Batch size.

  • k – Neighbor count.

  • g – Group count.

  • dg – Per-group sub-channel dimensionality.

Returns:

True on numerical agreement within self.eps.

Return type:

bool
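The per-shape comparison can be sketched in NumPy as a stand-in for the Keras/TensorFlow ops. The shape values in CONFIGS, the tolerance EPS (standing in for self.eps), the random inputs, and the softmax-over-neighbors normalization of the attention weights are all assumptions made for illustration; the real benchmark sizes and tolerance are not given here.

```python
import numpy as np

# Hypothetical (b, k, g, dg) configurations; the real "small"/"large" sizes may differ.
CONFIGS = {"small": (2, 16, 4, 8), "large": (8, 32, 8, 16)}
EPS = 1e-5  # assumed tolerance standing in for self.eps


def make_inputs(b, k, g, dg, seed=0):
    """Draw float32 inputs; attention is softmax-normalized over the neighbor axis (assumed)."""
    rng = np.random.default_rng(seed)
    logits = rng.standard_normal((b, k, g)).astype(np.float32)
    attn = np.exp(logits - logits.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)
    w = rng.standard_normal((b, k, g, dg)).astype(np.float32)
    v_mod = rng.standard_normal((b, k, g, dg)).astype(np.float32)
    return attn, w, v_mod


def run_case(name, b, k, g, dg):
    """Standalone sketch: fused einsum vs. reshape/multiply/multiply/reduce, within EPS."""
    attn, w, v_mod = make_inputs(b, k, g, dg)
    fused = np.einsum("bkg,bkge,bkge->bge", attn, w, v_mod)
    baseline = (attn.reshape(b, k, g, 1) * w * v_mod).sum(axis=1)
    max_err = float(np.max(np.abs(fused - baseline)))
    print(f"{name}: shape={fused.shape} max_abs_err={max_err:.2e}")
    return fused.shape == (b, g, dg) and max_err <= EPS


results = {name: run_case(name, *cfg) for name, cfg in CONFIGS.items()}
```

Comparing the maximum absolute error against a fixed tolerance, rather than requiring equality, accommodates the different float32 summation orders of the two paths.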

baseline_output(attn, w, v_mod)

Compute the legacy Reshape + Multiply + Multiply + ReduceSum tail output.

Parameters:
  • attn – Attention weights tensor (B, K, G).

  • w – Weight-MLP output tensor (B, K, G, d_g).

  • v_mod – Modulated values tensor (B, K, G, d_g).

Returns:

Aggregated context (B, G, d_g).
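A NumPy sketch of the four-op tail, with each step mirroring one of the named ops. It assumes the Reshape turns attn into (B, K, G, 1) so it broadcasts over d_g; the exact multiplication order in the legacy block is also an assumption, though it does not change the result mathematically.

```python
import numpy as np


def baseline_output(attn, w, v_mod):
    """Reshape + Multiply + Multiply + ReduceSum, as a NumPy stand-in for the Keras ops."""
    b, k, g = attn.shape
    # Reshape: (B, K, G) -> (B, K, G, 1) so the weights broadcast over d_g.
    attn_4d = attn.reshape(b, k, g, 1)
    # Multiply: scale the weight-MLP output by the attention weights.
    weighted = attn_4d * w
    # Multiply: apply the result to the modulated values.
    prod = weighted * v_mod
    # ReduceSum over the neighbor axis K: (B, K, G, d_g) -> (B, G, d_g).
    return prod.sum(axis=1)


attn = np.ones((1, 3, 2), dtype=np.float32)
w = np.full((1, 3, 2, 4), 2.0, dtype=np.float32)
v_mod = np.full((1, 3, 2, 4), 0.5, dtype=np.float32)
out = baseline_output(attn, w, v_mod)
print(out.shape)  # (1, 2, 4); every entry is 3 * (1 * 2.0 * 0.5) = 3.0
```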

fused_output(attn, w, v_mod)

Compute the TORFGVALayer output on CPU for a deterministic comparison with the baseline.

Parameters:
  • attn – Attention weights tensor (B, K, G).

  • w – Weight-MLP output tensor (B, K, G, d_g).

  • v_mod – Modulated values tensor (B, K, G, d_g).

Returns:

Aggregated context (B, G, d_g).
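The fused path reduces to a single einsum call. The sketch below uses NumPy, which always runs on the CPU; in TensorFlow, pinning the computation would presumably be done with a `with tf.device("/CPU:0"):` block, but that is an assumption about the test, not something stated here. The explicit-loop loop_reference is a hypothetical helper that checks the subscripts spell out the intended sum.

```python
import numpy as np


def fused_output(attn, w, v_mod):
    """Single contraction: out[b, g, e] = sum_k attn[b, k, g] * w[b, k, g, e] * v_mod[b, k, g, e]."""
    return np.einsum("bkg,bkge,bkge->bge", attn, w, v_mod)


def loop_reference(attn, w, v_mod):
    """Explicit loops over the same indices, used only to check the einsum subscripts."""
    b_n, k_n, g_n = attn.shape
    e_n = w.shape[-1]
    out = np.zeros((b_n, g_n, e_n), dtype=np.float64)
    for b in range(b_n):
        for g in range(g_n):
            for e in range(e_n):
                for k in range(k_n):  # k is the summed (neighbor) index
                    out[b, g, e] += attn[b, k, g] * w[b, k, g, e] * v_mod[b, k, g, e]
    return out


rng = np.random.default_rng(1)
attn = rng.random((2, 5, 3))
w = rng.standard_normal((2, 5, 3, 4))
v_mod = rng.standard_normal((2, 5, 3, 4))
fused = fused_output(attn, w, v_mod)
print(fused.shape, np.allclose(fused, loop_reference(attn, w, v_mod)))  # (2, 3, 4) True
```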

run_serialization()

Check that get_config followed by from_config yields a layer whose outputs are bitwise-identical to the original on a shared input.

Returns:

True on identical outputs.

Return type:

bool
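The round-trip check follows the usual Keras pattern: serialize with get_config(), rebuild with from_config(), then compare outputs with exact equality. The sketch below is a framework-free stand-in; ToyFusedLayer is hypothetical and weight-free, and its from_config mirrors the Keras default, which essentially calls cls(**config).

```python
import numpy as np


class ToyFusedLayer:
    """Hypothetical weight-free layer mimicking the Keras get_config/from_config contract."""

    def __init__(self, equation="bkg,bkge,bkge->bge", name="toy_fused"):
        self.equation = equation
        self.name = name

    def get_config(self):
        # Everything needed to rebuild the layer, as plain Python values.
        return {"equation": self.equation, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Mirrors the default Keras behavior of rebuilding via cls(**config).
        return cls(**config)

    def __call__(self, attn, w, v_mod):
        return np.einsum(self.equation, attn, w, v_mod)


rng = np.random.default_rng(2)
inputs = (
    rng.random((2, 6, 4), dtype=np.float32),
    rng.standard_normal((2, 6, 4, 8), dtype=np.float32),
    rng.standard_normal((2, 6, 4, 8), dtype=np.float32),
)
layer = ToyFusedLayer()
clone = ToyFusedLayer.from_config(layer.get_config())
same = np.array_equal(layer(*inputs), clone(*inputs))  # bitwise comparison, no tolerance
print(clone.get_config() == layer.get_config(), same)  # True True
```

Exact equality is appropriate here, unlike the fused-vs-baseline comparison, because both layers run the identical computation on identical inputs.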

run_graph_build()

Build a small functional keras.Model wrapping the layer and verify a single forward pass succeeds.

Returns:

True when the model builds and produces an output tensor with the expected shape.

Return type:

bool

run_shape_validation()

Confirm that shape-validation errors in TORFGVALayer.build() raise src.model.deeplearn.deep_learning_exception.DeepLearningException.

Returns:

True when every invalid-shape combination raises.

Return type:

bool
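A sketch of the "every invalid combination raises" loop. The specific rules in validate_shapes are hypothetical: they encode the shape contract documented for baseline_output and fused_output (attn is (B, K, G); w and v_mod are (B, K, G, d_g)), not the actual checks inside TORFGVALayer.build(). DeepLearningException is redefined locally as a stand-in for the VL3D class.

```python
class DeepLearningException(Exception):
    """Stand-in for src.model.deeplearn.deep_learning_exception.DeepLearningException."""


def validate_shapes(attn_shape, w_shape, v_shape):
    """Hypothetical build()-style checks; shapes include the (possibly None) batch axis."""
    if len(attn_shape) != 3:
        raise DeepLearningException(f"attn must be rank 3 (B, K, G), got {attn_shape}")
    if len(w_shape) != 4 or len(v_shape) != 4:
        raise DeepLearningException("w and v_mod must be rank 4 (B, K, G, d_g)")
    if tuple(w_shape) != tuple(v_shape):
        raise DeepLearningException(f"w {w_shape} and v_mod {v_shape} must match")
    if tuple(w_shape[1:3]) != tuple(attn_shape[1:3]):
        raise DeepLearningException("attn (K, G) must match the (K, G) axes of w and v_mod")


def every_invalid_combination_raises(cases):
    """True only if validate_shapes rejects every case with DeepLearningException."""
    for shapes in cases:
        try:
            validate_shapes(*shapes)
        except DeepLearningException:
            continue
        return False  # this combination was accepted
    return True


INVALID = [
    ((2, 16), (2, 16, 4, 8), (2, 16, 4, 8)),     # attn has rank 2
    ((2, 16, 4), (2, 16, 4), (2, 16, 4, 8)),     # w has rank 3
    ((2, 16, 4), (2, 16, 4, 8), (2, 16, 4, 9)),  # d_g mismatch between w and v_mod
    ((2, 16, 4), (2, 12, 4, 8), (2, 12, 4, 8)),  # K mismatch with attn
    ((2, 16, 4), (2, 16, 5, 8), (2, 16, 5, 8)),  # G mismatch with attn
]
print(every_invalid_combination_raises(INVALID))  # True
```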