src.model.deeplearn.layer.torf_slice_center_layer
Classes
- class src.model.deeplearn.layer.torf_slice_center_layer.TORFSliceCenterLayer(*args, **kwargs)
- Author:
Alberto M. Esmoris Pena
Extract the center (first) element along axis 1, keeping the dimension so the result can be broadcast-subtracted or broadcast-multiplied against the full tensor.
Works on tensors of any rank \(\geq 3\). The slice inputs[:, 0:1, :] preserves all trailing dimensions:
3D input \((B, K, D)\) produces \((B, 1, D)\). Used in the PE computation to extract the center point’s coordinates.
4D input \((B, K, G, d_g)\) produces \((B, 1, G, d_g)\). Used in Grouped Vector Attention to extract the center point’s query vector per group.
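The two cases above can be sketched with NumPy, used here only as a stand-in for the layer's tensor backend. This is a minimal illustration, not the layer's implementation. slice_center is a hypothetical stand-in for call(), and the sizes are arbitrary. The element-wise product in the 4D case is only an example of a broadcast multiplication, not necessarily the relation the attention block computes.

```python
import numpy as np


def slice_center(inputs: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for TORFSliceCenterLayer.call()."""
    # 0:1 (rather than 0) keeps axis 1 with length 1; trailing axes are untouched.
    return inputs[:, 0:1, :]


rng = np.random.default_rng(0)

# 3D case (B, K, D): neighborhood coordinates with the center point stored first.
B, K, D = 2, 16, 3
coords = rng.normal(size=(B, K, D))
center = slice_center(coords)        # (B, 1, D)
relative = coords - center           # broadcasts over K -> (B, K, D)

# 4D case (B, K, G, d_g): per-group query vectors.
G, d_g = 4, 8
queries = rng.normal(size=(B, K, G, d_g))
keys = rng.normal(size=(B, K, G, d_g))
center_q = slice_center(queries)     # (B, 1, G, d_g)
weighted = center_q * keys           # broadcasts over K -> (B, K, G, d_g)

print(center.shape, relative.shape)      # (2, 1, 3) (2, 16, 3)
print(center_q.shape, weighted.shape)    # (2, 1, 4, 8) (2, 16, 4, 8)
```

Because the center is subtracted from every neighbor, the first row of relative is all zeros. The center point sits at the origin of its own local frame.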
- __init__(**kwargs)
See Layer and Layer.__init__().
- call(inputs, training=False, mask=False)
Slice the first element along axis 1.
- Parameters:
inputs – Tensor of shape (B, K, …) with rank >= 3.
- Returns:
Tensor of shape (B, 1, …) with trailing dimensions preserved.
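The reason call() slices with 0:1 instead of indexing with 0 can be checked directly. The hypothetical call_sketch below mirrors the documented contract in NumPy. Its explicit rank check is an addition of this sketch, since the documentation only states that rank >= 3 is expected. The rank-5 input shows that any number of trailing axes is preserved.

```python
import numpy as np


def call_sketch(inputs: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of call(): keep axis 1, preserve the rest."""
    # Assumption of this sketch: reject inputs below rank 3 with a clear message.
    if inputs.ndim < 3:
        raise ValueError(f"expected rank >= 3, got rank {inputs.ndim}")
    return inputs[:, 0:1, :]


x = np.arange(2 * 5 * 3 * 4 * 6).reshape(2, 5, 3, 4, 6)  # rank 5 also works
kept = call_sketch(x)   # (2, 1, 3, 4, 6): axis 1 kept with size 1
dropped = x[:, 0]       # (2, 3, 4, 6): axis 1 removed

# The kept slice broadcasts against the full tensor along axis 1.
diff = x - kept         # (2, 5, 3, 4, 6)

# The dropped slice misaligns: NumPy matches shapes from the right,
# so B=2 is compared against K=5 and the subtraction fails.
try:
    x - dropped
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
```

Here x - dropped raises only because B=2 ends up aligned with K=5. If those sizes happened to match, the subtraction would run silently against the wrong axes. Keeping the size-1 axis avoids that failure mode.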
- get_config()
Return necessary data to deserialize the layer.
- classmethod from_config(config)
Use given config data to deserialize the layer.
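The get_config()/from_config() pair follows the usual Keras round-trip contract, where the default from_config rebuilds the layer as cls(**config). Since __init__ accepts only the base Layer keyword arguments, the config presumably carries nothing layer-specific. The class below is a hypothetical plain-Python stand-in with no Keras dependency that illustrates the round trip under that assumption. A real Keras layer would start from super().get_config() rather than build the dict by hand.

```python
import numpy as np


class SliceCenterSketch:
    """Hypothetical stand-in illustrating the get_config()/from_config() round trip.

    Assumption: like the documented layer, it takes only base-layer keyword
    arguments (reduced here to ``name``), so the config holds nothing else.
    """

    def __init__(self, name: str = "slice_center"):
        self.name = name

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        # Same slice as the documented call(): keep axis 1 with size 1.
        return inputs[:, 0:1, :]

    def get_config(self) -> dict:
        # A Keras layer would extend super().get_config() (name, dtype, trainable, ...).
        return {"name": self.name}

    @classmethod
    def from_config(cls, config: dict) -> "SliceCenterSketch":
        # Mirrors the default Keras behaviour: rebuild from the config as kwargs.
        return cls(**config)


original = SliceCenterSketch(name="center")
restored = SliceCenterSketch.from_config(original.get_config())
```

Since the layer has no weights or hyperparameters of its own, a restored instance behaves identically to the original.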