src.model.deeplearn.layer.torf_gva_layer

Classes

TORFGVALayer(*args, **kwargs)

class src.model.deeplearn.layer.torf_gva_layer.TORFGVALayer(*args, **kwargs)
Author:

Alberto M. Esmoris Pena

Mega-fused tail of a Grouped Vector Attention (GVA) block. A single einsum contracts the attention weights, the weight-MLP output, and the modulated values into the per-group aggregated context, replacing the Reshape + Multiply + Multiply + ReduceSum quartet that otherwise materializes a full \((B, K, G, d_g)\) intermediate in VRAM before the \(K\)-reduction.

Given:

  • \(\boldsymbol{\alpha} \in \mathbb{R}^{B \times K \times G}\) — softmax attention weights per (sample, neighbor, group).

  • \(\mathbf{w} \in \mathbb{R}^{B \times K \times G \times d_g}\) — weight-MLP output per (sample, neighbor, group, sub-channel).

  • \(\mathbf{V}_{\text{mod}} \in \mathbb{R}^{B \times K \times G \times d_g}\) — modulated values \(\mathbf{V} + \mathbf{PE}\).

The layer computes:

\[\mathbf{A}_{b g e} = \sum_{k=1}^{K} \alpha_{b k g} \, w_{b k g e} \, (V_{\text{mod}})_{b k g e}\]

which is identical (up to floating-point reassociation of the reduction) to the sequence:

\[\mathbf{A} = \sum_{k} (\alpha_{b k g} \cdot w_{b k g e}) \odot (V_{\text{mod}})_{b k g e}\]

used by the legacy GVA tail. Output shape: \((B, G, d_g)\).
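The equivalence between the fused contraction and the legacy sequence can be checked numerically. The following is a minimal sketch, not the layer's implementation: it uses NumPy's einsum as a stand-in for keras.ops.einsum, the subscript string "bkg,bkge,bkge->bge" is an assumption derived from the formula above, and the sizes and random inputs are purely illustrative (alpha is not a real softmax output here).

```python
import numpy as np

# Hypothetical sizes chosen only for illustration.
B, K, G, d_g = 2, 5, 3, 4
rng = np.random.default_rng(0)

alpha = rng.random((B, K, G))        # stand-in for the attention weights
w = rng.random((B, K, G, d_g))       # stand-in for the weight-MLP output
v_mod = rng.random((B, K, G, d_g))   # stand-in for the modulated values V + PE

# Fused form: a single contraction that sums over the neighbor axis k.
fused = np.einsum("bkg,bkge,bkge->bge", alpha, w, v_mod)

# Legacy form: broadcast alpha over d_g, multiply twice, then reduce over K.
# The product below is the full (B, K, G, d_g) intermediate.
legacy = ((alpha[..., None] * w) * v_mod).sum(axis=1)

print(fused.shape)
print(np.allclose(fused, legacy))
```

In float64 the two results agree to within the default tolerance of np.allclose; at lower precision the difference from reassociation may be larger.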

The layer carries no trainable parameters. It is dtype-agnostic: it operates in whatever floating-point dtype keras.ops.einsum supports (float16 under mixed precision, float32 by default, float64 if ever requested).

__init__(**kwargs)

See Layer and Layer.__init__().

build(input_shape)

Validate the input shapes and finalize the build. No trainable weights are created.

Parameters:

input_shape – A list with exactly three shapes corresponding to attn_weights (B, K, G), w (B, K, G, d_g), and V_mod (B, K, G, d_g).

Raises:

src.model.deeplearn.deep_learning_exception.DeepLearningException – If the inputs do not match the expected rank triplet (3, 4, 4) or if the broadcast axes \(B\), \(K\), \(G\), \(d_g\) are inconsistent across the three operands.
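The checks described above can be sketched in plain Python. This is an illustration under stated assumptions, not the layer's code: check_gva_shapes is a hypothetical helper, it raises ValueError rather than DeepLearningException so that it runs without the project installed, and treating None as an unknown dimension that matches anything is an assumption.

```python
def check_gva_shapes(input_shape):
    """Hypothetical helper mirroring the checks described for build()."""
    if len(input_shape) != 3:
        raise ValueError("expected exactly three input shapes")
    a, w, v = (tuple(s) for s in input_shape)
    if (len(a), len(w), len(v)) != (3, 4, 4):
        raise ValueError("expected the rank triplet (3, 4, 4)")
    # Compare axis by axis; None marks an unknown (dynamic) dimension
    # and is skipped (an assumption of this sketch).
    axes = (
        ("B", (a[0], w[0], v[0])),
        ("K", (a[1], w[1], v[1])),
        ("G", (a[2], w[2], v[2])),
        ("d_g", (w[3], v[3])),
    )
    for name, dims in axes:
        known = {d for d in dims if d is not None}
        if len(known) > 1:
            raise ValueError(f"inconsistent {name} axis: {dims}")


# Consistent shapes pass silently.
check_gva_shapes([(None, 16, 4), (None, 16, 4, 8), (None, 16, 4, 8)])
```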

call(inputs, training=False, mask=False)

Execute the fused GVA contraction.

Parameters:

inputs – A list [attn_weights, w, V_mod] of three tensors with shapes \((B, K, G)\), \((B, K, G, d_g)\), and \((B, K, G, d_g)\), respectively.

Returns:

The aggregated context tensor of shape \((B, G, d_g)\).

compute_output_shape(input_shape)

Return the static output shape \((B, G, d_g)\).

Parameters:

input_shape – The list of three input shapes as documented in build().

Returns:

The output shape.

Return type:

tuple
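As a rough illustration of the shape rule, the output shape is the 4-D operand shape with the neighbor axis \(K\) removed. The function gva_output_shape below is a hypothetical stand-in written for this example, not the method itself, and it assumes the three shapes have already been validated.

```python
def gva_output_shape(input_shape):
    """Hypothetical stand-in: drop K from the (B, K, G, d_g) operand shape."""
    _, w_shape, _ = input_shape
    b, _k, g, d_g = w_shape
    return (b, g, d_g)


# A dynamic batch axis (None) is carried through unchanged.
shape = gva_output_shape([(None, 16, 4), (None, 16, 4, 8), (None, 16, 4, 8)])
print(shape)
```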

get_config()

Return the dictionary necessary to deserialize the layer.

classmethod from_config(config)

Deserialize a TORFGVALayer from its configuration.