src.model.deeplearn.layer.torf_slice_center_layer

Classes

TORFSliceCenterLayer(*args, **kwargs)

class src.model.deeplearn.layer.torf_slice_center_layer.TORFSliceCenterLayer(*args, **kwargs)
Author:

Alberto M. Esmoris Pena

Extract the center (first) element along axis 1, keeping the dimension so the result can be broadcast-subtracted or broadcast-multiplied against the full tensor.

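Works on tensors of any rank \(\geq 3\). The slice inputs[:, 0:1, :] takes a length-1 range on axis 1 rather than an integer index, so the axis is kept with size 1 and all trailing dimensions are preserved.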

  • 3D input \((B, K, D)\) produces \((B, 1, D)\). Used in the PE computation to extract the center point’s coordinates.

  • 4D input \((B, K, G, d_g)\) produces \((B, 1, G, d_g)\). Used in Grouped Vector Attention to extract the center point’s query vector per group.
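The two cases above can be illustrated with a minimal sketch. It uses NumPy as a stand-in for the tensor library, on the assumption that the slicing and broadcasting semantics match what the layer relies on. The sizes and variable names are hypothetical, chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes: B=2 neighborhoods, K=4 points each, D=3 coordinates.
B, K, D = 2, 4, 3
x3 = np.arange(B * K * D, dtype=np.float32).reshape(B, K, D)

center3 = x3[:, 0:1, :]  # range slice keeps axis 1 -> (B, 1, D)
dropped = x3[:, 0, :]    # integer index drops axis 1 -> (B, D)

# Broadcasting (B, K, D) - (B, 1, D): each point minus its own center.
relative = x3 - center3

# 4D case with hypothetical G=2 groups and d_g=5 features per group.
G, d_g = 2, 5
x4 = np.ones((B, K, G, d_g), dtype=np.float32)
center4 = x4[:, 0:1, :]  # unlisted trailing axes are kept whole -> (B, 1, G, d_g)
scaled = x4 * center4    # broadcast-multiply against the full tensor
```

The integer-indexed form has shape (B, D), which in general does not broadcast against (B, K, D). This is why the layer keeps the sliced axis.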

__init__(**kwargs)

See Layer and Layer.__init__().

call(inputs, training=False, mask=False)

Slice the first element along axis 1, keeping that axis with size 1.

Parameters:

inputs – Tensor of shape (B, K, …) with rank >= 3.

Returns:

Tensor of shape (B, 1, …) with trailing dimensions preserved.

get_config()

Return the data necessary to deserialize the layer.

classmethod from_config(config)

Use the given config data to deserialize the layer.
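As a rough illustration of how get_config() and from_config() are meant to pair up, the sketch below uses a plain Python stand-in class. SliceCenterSketch is hypothetical and is not the real TORFSliceCenterLayer. It assumes that the layer carries no state beyond its name, which the source does not confirm.

```python
import numpy as np


class SliceCenterSketch:
    """Hypothetical stand-in for the config round trip; not the real layer."""

    def __init__(self, name="slice_center"):
        self.name = name

    def call(self, inputs):
        # Same slice the documentation describes: keep axis 1 with size 1.
        return inputs[:, 0:1, :]

    def get_config(self):
        # Assumption: nothing besides the name needs to be stored.
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent instance from the stored dictionary.
        return cls(**config)


original = SliceCenterSketch(name="center")
restored = SliceCenterSketch.from_config(original.get_config())
x = np.zeros((2, 4, 3))
```

Under these assumptions, restored behaves like original: both apply the same slice and report the same configuration.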