src.model.deeplearn.layer.torf_gva_layer
Classes
- class src.model.deeplearn.layer.torf_gva_layer.TORFGVALayer(*args, **kwargs)
- Author:
Alberto M. Esmoris Pena
Mega-fused tail of a Grouped Vector Attention (GVA) block. A single einsum contracts the attention weights, the weight-MLP output, and the modulated values into the per-group aggregated context, replacing the Reshape + Multiply + Multiply + ReduceSum quartet that otherwise materializes a full \((B, K, G, d_g)\) intermediate in VRAM before the \(K\)-reduction.
Given:
\(\boldsymbol{\alpha} \in \mathbb{R}^{B \times K \times G}\) — softmax attention weights per (sample, neighbor, group).
\(\mathbf{w} \in \mathbb{R}^{B \times K \times G \times d_g}\) — weight-MLP output per (sample, neighbor, group, sub-channel).
\(\mathbf{V}_{\text{mod}} \in \mathbb{R}^{B \times K \times G \times d_g}\) — modulated values \(\mathbf{V} + \mathbf{PE}\).
The layer computes:
\[\mathbf{A}_{b g e} = \sum_{k=1}^{K} \alpha_{b k g} \, w_{b k g e} \, (V_{\text{mod}})_{b k g e}\]
which is identical (up to floating-point reduction reassociation) to the sequence:
\[\mathbf{A}_{b g e} = \sum_{k=1}^{K} \left(\alpha_{b k g} \cdot w_{b k g e}\right) \cdot (V_{\text{mod}})_{b k g e}\]
used by the legacy GVA tail. Output shape: \((B, G, d_g)\).
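The equivalence between the fused contraction and the legacy multiply-and-reduce sequence can be checked numerically. The following is a minimal sketch, not the layer's actual implementation: NumPy's `np.einsum` stands in for `keras.ops.einsum`, the subscript string `"bkg,bkge,bkge->bge"` is read directly off the formula above, and the tensor sizes and random inputs are hypothetical.

```python
import numpy as np

# Hypothetical toy sizes: batch B, neighbors K, groups G, sub-channels d_g.
B, K, G, d_g = 2, 5, 3, 4
rng = np.random.default_rng(0)

# Assumption: alpha is normalized with a softmax over the neighbor axis K.
logits = rng.normal(size=(B, K, G))
alpha = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
w = rng.normal(size=(B, K, G, d_g))
v_mod = rng.normal(size=(B, K, G, d_g))

# Fused form: a single contraction over the neighbor index k.
fused = np.einsum("bkg,bkge,bkge->bge", alpha, w, v_mod)

# Legacy form: add a trailing axis to alpha, multiply twice, then sum over K.
legacy = ((alpha[..., None] * w) * v_mod).sum(axis=1)

print(fused.shape)                  # shape of the aggregated context
print(np.allclose(fused, legacy))   # agreement up to floating-point tolerance
```

The two results agree only up to rounding because the fused form may accumulate the products in a different order than the explicit sum.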
The layer carries no trainable parameters. It is dtype-agnostic and forwards whatever dtype
keras.ops.einsum supports (float16 under mixed precision, float32 by default, float64 if ever requested).
- __init__(**kwargs)
See Layer and Layer.__init__().
- build(input_shape)
Validate the input shapes and finalize the build. No trainable weights are created.
- Parameters:
input_shape – A list with exactly three shapes corresponding to attn_weights (B, K, G), w (B, K, G, d_g), and V_mod (B, K, G, d_g).
- Raises:
src.model.deeplearn.deep_learning_exception.DeepLearningException – If the inputs do not match the expected rank triplet (3, 4, 4) or if the broadcast axes \(B\), \(K\), \(G\), \(d_g\) are inconsistent across the three operands.
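The checks described for build() can be sketched as a standalone function. This is an illustration under stated assumptions rather than the project's code: `validate_gva_shapes` is a hypothetical name, `ValueError` stands in for `DeepLearningException`, and treating `None` as an unknown dimension that matches anything is an assumption about how static shapes are compared.

```python
def validate_gva_shapes(input_shape):
    """Hypothetical helper mirroring the shape checks described for build()."""
    if len(input_shape) != 3:
        raise ValueError(f"Expected 3 input shapes, got {len(input_shape)}.")
    attn, w, v_mod = (tuple(s) for s in input_shape)

    # Rank triplet must be (3, 4, 4).
    ranks = (len(attn), len(w), len(v_mod))
    if ranks != (3, 4, 4):
        raise ValueError(f"Expected ranks (3, 4, 4), got {ranks}.")

    def same(a, b):
        # Assumption: None marks an unknown static dimension and matches anything.
        return a is None or b is None or a == b

    # B, K, G must agree across all three operands.
    for name, axis in (("B", 0), ("K", 1), ("G", 2)):
        dims = (attn[axis], w[axis], v_mod[axis])
        if not (same(dims[0], dims[1]) and same(dims[1], dims[2])
                and same(dims[0], dims[2])):
            raise ValueError(f"Inconsistent {name} axis: {dims}.")

    # d_g appears only in w and V_mod.
    if not same(w[3], v_mod[3]):
        raise ValueError(f"Inconsistent d_g axis: {(w[3], v_mod[3])}.")


# Passes silently: unknown batch size, K=16, G=4, d_g=8.
validate_gva_shapes([(None, 16, 4), (None, 16, 4, 8), (None, 16, 4, 8)])
```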
- call(inputs, training=False, mask=False)
Execute the fused GVA contraction.
- Parameters:
inputs – A list [attn_weights, w, V_mod] of three tensors with shapes \((B, K, G)\), \((B, K, G, d_g)\), and \((B, K, G, d_g)\), respectively.
- Returns:
The aggregated context tensor of shape \((B, G, d_g)\).
- compute_output_shape(input_shape)
Return the static output shape \((B, G, d_g)\).
- Parameters:
input_shape – The list of three input shapes as documented in build().
- Returns:
The output shape.
- Return type:
tuple
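Since the contraction removes only the neighbor axis \(K\), the static output shape follows from any of the rank-4 inputs. A minimal sketch, assuming the shape is read from the second operand (`w`); the function name `gva_output_shape` is hypothetical and the real method may derive the shape differently:

```python
def gva_output_shape(input_shape):
    """Hypothetical sketch: drop the K axis from the (B, K, G, d_g) shape of w."""
    _, w_shape, _ = input_shape
    batch, _, groups, sub_channels = w_shape
    return (batch, groups, sub_channels)


# Unknown batch size, K=16 neighbors, G=4 groups, d_g=8 sub-channels.
out = gva_output_shape([(None, 16, 4), (None, 16, 4, 8), (None, 16, 4, 8)])
print(out)
```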
- get_config()
Return the dictionary necessary to deserialize the layer.
- classmethod from_config(config)
Deserialize a TORFGVALayer from its configuration.