src.model.deeplearn.layer.torf_point_wise_mask_layer

Classes

TORFPointWiseMaskLayer(*args, **kwargs)

class src.model.deeplearn.layer.torf_point_wise_mask_layer.TORFPointWiseMaskLayer(*args, **kwargs)
Author:

Alberto M. Esmoris Pena

Zero out invalid positions in a per-point output tensor using a boolean (or float) mask.

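Given tensors [X, M], where X has shape \((B, K, C)\) and M has shape \((B, K)\), with \(B\) the batch size, \(K\) the number of points, and \(C\) the number of classes (channels):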

\[\begin{split}f(\mathbf{X}, \mathbf{M})_{b,k,c} = \begin{cases} X_{b,k,c} & \text{if } M_{b,k} \neq 0 \\ 0 & \text{otherwise} \end{cases}\end{split}\]
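The piecewise definition can be checked with a small NumPy reference. This is a sketch of the math only, not the layer's TensorFlow implementation; the function name masked_output and the example values are hypothetical. It selects on M != 0, exactly as the formula is written.

```python
import numpy as np


def masked_output(X: np.ndarray, M: np.ndarray) -> np.ndarray:
    """Hypothetical reference for f(X, M): keep X[b, k, :] where M[b, k] != 0, else 0."""
    # (M != 0) has shape (B, K); the new trailing axis makes it (B, K, 1)
    # so the same keep/drop decision applies to all C values of a point.
    keep = (M != 0)[..., np.newaxis]
    return np.where(keep, X, np.zeros_like(X))


# Example: B=1 batch, K=3 points, C=2 classes; the second point is invalid.
X = np.array([[[0.2, 0.8], [0.6, 0.4], [0.9, 0.1]]])
M = np.array([[1.0, 0.0, 1.0]])
Y = masked_output(X, M)
print(Y)
```

The second point has M = 0, so its row in Y becomes [0, 0], while the other rows are copied unchanged from X.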

The mask is given a trailing axis via expand_dims, going from shape \((B, K)\) to \((B, K, 1)\), so that it broadcasts correctly against the class dimension \(C\). This avoids a Reshape, which can cause shape mismatches in XLA during graph compilation.
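The broadcasting argument can be seen with NumPy, whose broadcasting rules match TensorFlow's. This sketch multiplies by the expanded mask; it assumes M holds 0/1 floats, as documented for call(), and does not claim that the layer itself multiplies rather than selects. For finite X and a 0/1 mask, the product equals the piecewise definition above. The shapes and values are arbitrary.

```python
import numpy as np

B, K, C = 2, 4, 3
rng = np.random.default_rng(0)
X = rng.random((B, K, C)).astype(np.float32)
M = np.array([[1, 1, 0, 1], [0, 1, 1, 0]], dtype=np.float32)  # shape (B, K)

# Add a trailing axis of size 1: (B, K) -> (B, K, 1).
M_exp = np.expand_dims(M, axis=-1)
print(M_exp.shape)  # (2, 4, 1)

# The size-1 axis broadcasts against C, so all C values of a point share one mask value.
Y = X * M_exp
print(Y.shape)  # (2, 4, 3)

# Without the extra axis, broadcasting aligns trailing axes, i.e. (B, K) against (K, C),
# which fails here because K=4 != C=3.
try:
    X * M
except ValueError as err:
    print("broadcast error:", err)
```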

__init__(**kwargs)

See Layer and Layer.__init__().

call(inputs, training=False, mask=False)

Apply per-point masking.

Parameters:

inputs – List of two tensors [X, M]. X has shape (B, K, C) and M has shape (B, K), with float values in {0, 1} (nonzero marks a valid point).

Returns:

Masked tensor of shape (B, K, C).
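The contract of call() (a list [X, M] in, a (B, K, C) tensor out) can be mirrored by a plain NumPy function. The name apply_point_wise_mask, the shape checks, and the padding scenario in the usage lines are assumptions for illustration, not part of the documented API.

```python
import numpy as np


def apply_point_wise_mask(inputs):
    """Hypothetical NumPy stand-in for call(): inputs = [X, M] -> masked X."""
    if len(inputs) != 2:
        raise ValueError("expected inputs = [X, M]")
    X, M = (np.asarray(t) for t in inputs)
    # X must be (B, K, C) and M must match its first two axes, (B, K).
    if X.ndim != 3 or M.shape != X.shape[:2]:
        raise ValueError(f"expected X: (B, K, C) and M: (B, K), got {X.shape} and {M.shape}")
    # Expand M to (B, K, 1) so it broadcasts over the C axis.
    return X * np.expand_dims(M.astype(X.dtype), axis=-1)


# Usage (assumed scenario): a padded batch where the second sample
# has only 2 real points out of K=3.
X = np.ones((2, 3, 5), dtype=np.float32)
M = np.array([[1, 1, 1], [1, 1, 0]], dtype=np.float32)
Y = apply_point_wise_mask([X, M])
print(Y.shape)  # (2, 3, 5)
print(Y.sum(axis=-1))
```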

get_config()

Return the data necessary to deserialize the layer.

classmethod from_config(config)

Use the given config data to deserialize the layer.
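get_config() and from_config() follow the usual Keras serialization round trip: get_config() returns a dict, and from_config() rebuilds an equivalent object from it, by default as cls(**config). The toy class below is hypothetical and deliberately framework-free; it only shows the pattern, not what TORFPointWiseMaskLayer actually stores in its config.

```python
class ToyMaskLayer:
    """Hypothetical stand-in illustrating the get_config/from_config round trip."""

    def __init__(self, name=None, trainable=True, **kwargs):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # Everything __init__ needs to rebuild an equivalent layer.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the stored config back to __init__.
        return cls(**config)


layer = ToyMaskLayer(name="point_mask")
clone = ToyMaskLayer.from_config(layer.get_config())
print(clone.name, clone.trainable)  # point_mask True
```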