src.model.deeplearn.layer.torf_point_wise_mask_layer
Classes
- class src.model.deeplearn.layer.torf_point_wise_mask_layer.TORFPointWiseMaskLayer(*args, **kwargs)
- Author:
Alberto M. Esmoris Pena
Zero out invalid positions in a per-point output tensor using a boolean (or float) mask.
Given tensors [X, M], where X has shape \((B, K, C)\) and M has shape \((B, K)\):
\[f(\mathbf{X}, \mathbf{M})_{b,k,c} = \begin{cases} X_{b,k,c} & \text{if } M_{b,k} \neq 0 \\ 0 & \text{otherwise} \end{cases}\]
The mask is expanded along the last axis via expand_dims, from \((B, K)\) to \((B, K, 1)\), so that it broadcasts correctly against the class dimension. This avoids Reshape, which can cause XLA shape mismatches during graph compilation.
- __init__(**kwargs)
See Layer and Layer.__init__().
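As a sketch of the shape reasoning behind the expand_dims remark in the class description, the NumPy snippet below uses toy sizes (the values of B, K and C are arbitrary). NumPy aligns trailing axes when broadcasting, and TensorFlow follows the same rules, so a (B, K) mask cannot be combined directly with a (B, K, C) tensor, whereas a (B, K, 1) mask can. It illustrates shapes only; it is not the layer's TensorFlow code and does not reproduce the XLA behavior of Reshape.

```python
import numpy as np

# Toy sizes chosen for illustration: B batches, K points, C classes.
B, K, C = 2, 4, 3
X = np.zeros((B, K, C), dtype=np.float32)  # per-point output, shape (B, K, C)
M = np.ones((B, K), dtype=np.float32)      # per-point mask, shape (B, K)

# Direct broadcasting aligns M's last axis (K) with X's last axis (C).
# Here K != C, so NumPy raises ValueError; with other sizes it could even
# broadcast silently along the wrong axes.
try:
    _ = X * M
    direct_broadcast_ok = True
except ValueError:
    direct_broadcast_ok = False

# Adding a trailing singleton axis yields (B, K, 1), which broadcasts over C.
M_exp = np.expand_dims(M, axis=-1)
product_shape = (X * M_exp).shape
```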
- call(inputs, training=False, mask=False)
Apply per-point masking.
- Parameters:
inputs – List of two tensors [X, M]. X has shape (B, K, C) and M has shape (B, K) (float, 0/1).
- Returns:
Masked tensor of shape (B, K, C).
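The behavior documented for call can be mirrored by a small NumPy reference function, for example to check the layer's output in tests. The name point_wise_mask_reference and its shape validation are hypothetical and not part of this module; the function simply follows the masking formula in the class description, keeping X[b, k, :] wherever M[b, k] is nonzero, so a boolean mask and a 0/1 float mask give the same result.

```python
import numpy as np


def point_wise_mask_reference(inputs):
    """Hypothetical NumPy reference for the documented call() semantics.

    inputs: [X, M] with X of shape (B, K, C) and M of shape (B, K).
    Rows of X whose mask value is nonzero are kept; all others become 0.
    """
    X, M = inputs
    X = np.asarray(X)
    M = np.asarray(M)
    if X.ndim != 3 or M.shape != X.shape[:2]:
        raise ValueError(f"expected X (B, K, C) and M (B, K), got {X.shape} and {M.shape}")
    keep = np.expand_dims(M != 0, axis=-1)       # boolean, shape (B, K, 1)
    return np.where(keep, X, np.zeros_like(X))   # shape (B, K, C), dtype of X
```

Using np.where with the test M != 0 matches the formula exactly; a plain product X * M would agree only for 0/1 masks, since it would also scale X by any other nonzero mask value.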
- get_config()
Return necessary data to deserialize the layer.
- classmethod from_config(config)
Use given config data to deserialize the layer.