# Source code for pytorch3d.implicitron.models.implicit_function.decoding_functions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
This file contains
    - modules which get used by ImplicitFunction objects for decoding an embedding defined in
        space, e.g. to color or opacity.
    - DecoderFunctionBase and its subclasses, which wrap some of those modules, providing
        some such modules as an extension point which an ImplicitFunction object could use.
"""

import logging
from dataclasses import field

from enum import Enum
from typing import Dict, Optional, Tuple

import torch

from omegaconf import DictConfig

from pytorch3d.implicitron.tools.config import (
    Configurable,
    registry,
    ReplaceableBase,
    run_auto_creation,
)

logger = logging.getLogger(__name__)


class DecoderActivation(Enum):
    """
    Names of the element-wise activations that decoding functions in this
    module can apply (used by `ElementwiseDecoder.operation` and
    `MLPWithInputSkips.last_activation`).
    """

    # max(x, 0)
    RELU = "relu"
    # smooth approximation of RELU
    SOFTPLUS = "softplus"
    # squashes values into (0, 1)
    SIGMOID = "sigmoid"
    # no activation; the input is passed through unchanged
    IDENTITY = "identity"
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
    """
    Base class of the decoding functions: torch.nn.Module objects that map the
    embedding of a location in space to the quantity required there (for
    example density and color). Concrete subclasses override `forward`.
    """

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
            z: optional tensor to append to parts of the decoding function
        Returns:
            decoded_features (torch.Tensor) : tensor of shape
                (batch, ..., num_out_features)
        Raises:
            NotImplementedError: always; the base class has no implementation.
        """
        raise NotImplementedError
@registry.register
class ElementwiseDecoder(DecoderFunctionBase):
    """
    Decoding function computing, element by element,
    `result = operation(input * scale + shift)`, where `operation` is `relu`,
    `softplus`, `sigmoid` or nothing at all.

    Members:
        scale: scalar the input is multiplied by before the shift is added.
            Defaults to 1.
        shift: scalar added to the scaled input before the operation is
            applied. Defaults to 0.
        operation: the operation applied to the transformed input; one of
            `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
    """

    scale: float = 1
    shift: float = 0
    operation: DecoderActivation = DecoderActivation.IDENTITY

    def __post_init__(self):
        supported = (
            DecoderActivation.RELU,
            DecoderActivation.SOFTPLUS,
            DecoderActivation.SIGMOID,
            DecoderActivation.IDENTITY,
        )
        if self.operation in supported:
            return
        raise ValueError(
            "`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
        )

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: the tensor to decode.
            z: unused; accepted for compatibility with `DecoderFunctionBase`.
        Returns:
            `operation(features * scale + shift)`.
        """
        affine = features * self.scale + self.shift
        op = self.operation
        if op == DecoderActivation.RELU:
            return torch.nn.functional.relu(affine)
        if op == DecoderActivation.SOFTPLUS:
            return torch.nn.functional.softplus(affine)
        if op == DecoderActivation.SIGMOID:
            return torch.nn.functional.sigmoid(affine)
        # IDENTITY: no activation is applied.
        return affine
class MLPWithInputSkips(Configurable, torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.

    As such, `MLPWithInputSkips` is a multi layer perceptron consisting
    of a sequence of linear layers with ReLU activations; only the last layer
    uses the configurable `last_activation`.

    Additionally, for a set of predefined layers `input_skips`, the forward pass
    appends a skip tensor `z` to the output of the preceding layer (or, when
    `skip_affine_trans` is set, uses `z` to affinely transform that output).
    Layer 0 is never a skip layer: an index 0 in `input_skips` is ignored.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1]; to reproduce it, keep the defaults:
        - `last_layer_bias_init` to None
        - `last_activation` to "relu"
        - `use_xavier_init` to `True`

    If you want to use this as a part of the color prediction in TensoRF model set:
        - `last_layer_bias_init` to 0
        - `last_activation` to "sigmoid"
        - `use_xavier_init` to `False`

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020

    Members:
        n_layers: The number of linear layers of the MLP.
        input_dim: The number of channels of the input tensor.
        output_dim: The number of channels of the output.
        skip_dim: The number of channels of the tensor `z` appended when
            evaluating the skip layers.
        hidden_dim: The number of hidden units of the MLP.
        input_skips: The list of layer indices at which we append the skip
            tensor `z`.
        skip_affine_trans: If True, `z` is not concatenated at the skip layers.
            Instead a small MLP maps `z` to `(mu, log_std)` and the hidden
            features `y` are replaced by `(y - mu) * softplus(log_std)`.
        last_layer_bias_init: If set then all the biases in the last layer
            are initialized to that value.
        last_activation: Which activation to use in the last layer. Options are:
            "relu", "softplus", "sigmoid" and "identity". Default is "relu".
        use_xavier_init: If True uses xavier init for all linear layer weights.
            Otherwise the default PyTorch initialization is used. Default True.
    """

    n_layers: int = 8
    input_dim: int = 39
    output_dim: int = 256
    skip_dim: int = 39
    hidden_dim: int = 256
    input_skips: Tuple[int, ...] = (5,)
    skip_affine_trans: bool = False
    last_layer_bias_init: Optional[float] = None
    last_activation: DecoderActivation = DecoderActivation.RELU
    use_xavier_init: bool = True

    def __post_init__(self):
        try:
            last_activation = {
                DecoderActivation.RELU: torch.nn.ReLU(True),
                DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
                DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
                DecoderActivation.IDENTITY: torch.nn.Identity(),
            }[self.last_activation]
        except KeyError as e:
            raise ValueError(
                "`last_activation` can only be `RELU`,"
                " `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
            ) from e

        layers = []
        skip_affine_layers = []
        for layeri in range(self.n_layers):
            is_last = layeri + 1 == self.n_layers
            dimin = self.hidden_dim if layeri > 0 else self.input_dim
            dimout = self.output_dim if is_last else self.hidden_dim

            if layeri > 0 and layeri in self.input_skips:
                if self.skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(self.skip_dim, self.hidden_dim)
                    )
                else:
                    # `z` is concatenated to the hidden features at this layer.
                    dimin = self.hidden_dim + self.skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            if self.use_xavier_init:
                _xavier_init(linear)
            if is_last and self.last_layer_bias_init is not None:
                torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)

            # Hidden layers use ReLU; only the final layer uses `last_activation`.
            # (The previous `if not layeri + 1 < n_layers` parsed as
            # `not (layeri + 1 < n_layers)` and swapped the two.)
            activation = last_activation if is_last else torch.nn.ReLU(True)
            layers.append(torch.nn.Sequential(linear, activation))

        self.mlp = torch.nn.ModuleList(layers)
        if self.skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        # Layer 0 is never built as a skip layer in the loop above, so it must
        # not be treated as one in `forward` either (that would concatenate `z`
        # onto an input_dim-wide layer, or index a missing affine layer).
        self._input_skips = {li for li in self.input_skips if li > 0}
        self._skip_affine_trans = self.skip_affine_trans

    def _make_affine_layer(self, input_dim, hidden_dim):
        """
        Builds the small MLP mapping the skip tensor to the concatenation of
        `mu` and `log_std`, each with `hidden_dim` channels.
        """
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
        if self.use_xavier_init:
            _xavier_init(l1)
            _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
        """
        Returns `(x - mu) * softplus(log_std)`, where `(mu, log_std)` are the
        two halves (along the last dim) of `layer(z)`.
        """
        mu_log_std = layer(z)
        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
                Defaults to `x` when None.
        Returns:
            y: The output tensor of shape `(..., output_dim)`.
        """
        y = x
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x
        skipi = 0
        for li, layer in enumerate(self.mlp):
            if li in self._input_skips:
                if self._skip_affine_trans:
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
                else:
                    y = torch.cat((y, z), dim=-1)
                skipi += 1
            y = layer(y)
        return y
@registry.register
# pyre-fixme[13]: Attribute `network` is never initialized.
class MLPDecoder(DecoderFunctionBase):
    """
    Decoding function that turns the embedding into its output with an
    `MLPWithInputSkips` network held in the `network` member. That network's
    `input_dim` is not configured separately; it is always taken from this
    decoder's `input_dim` member.

    Members:
        input_dim: dimension of input.
        param_groups: dictionary where keys are names of individual parameters
            or module members and values are the parameter group where the
            parameter/member will be sorted to. "self" key is used to denote the
            parameter group at the module level. Possible keys, including the "self" key
            do not have to be defined. By default all parameters are put into "default"
            parameter group and have the learning rate defined in the optimizer,
            it can be overridden at the:
                - module level with "self" key, all the parameters and child
                    module's parameters will be put to that parameter group
                - member level, which is the same as if the `param_groups` in that
                    member has key="self" and value equal to that parameter group.
                    This is useful if members do not have `param_groups`, for
                    example torch.nn.Linear.
                - parameter level, parameter with the same name as the key
                    will be put to that parameter group.
        network_args: configuration for MLPWithInputSkips
    """

    input_dim: int = 3
    param_groups: Dict[str, str] = field(default_factory=dict)
    network: MLPWithInputSkips

    def __post_init__(self):
        run_auto_creation(self)

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Runs the wrapped `network` on `features`, forwarding `z` as its skip tensor."""
        mlp = self.network
        return mlp(features, z)

    @classmethod
    def network_tweak_args(cls, type, args: DictConfig) -> None:
        """
        Special method to stop get_default_args exposing member's `input_dim`:
        the decoder's own `input_dim` is the single source of truth for it.
        """
        args.pop("input_dim", None)

    def create_network_impl(self, type, args: DictConfig) -> None:
        """
        Builds `network`, forcing its input dimension to equal the input
        dimension of this decoding function.
        """
        network_input_dim = self.input_dim
        self.network = MLPWithInputSkips(input_dim=network_input_dim, **args)
class TransformerWithInputSkips(torch.nn.Module):
    """
    Transformer counterpart of `MLPWithInputSkips`.

    The network works as follows:
        - An input linear layer maps `x` to `hidden_dim` channels.
        - Each of the `n_layers` stages applies one `TransformerEncoderLayer`
          attending across the pooled features (dim 1 of the input), then one
          attending across the points of each ray (dim 3 of the input).
        - The pooled features are merged by a softmax-weighted sum, whose logits
          are the first feature channel.
        - A final linear layer projects the result to `output_dim`.
    """

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 64,
        input_skips: Tuple[int, ...] = (5,),
        dim_down_factor: float = 1,
    ):
        """
        Args:
            n_layers: The number of stages, each made of one pooled-feature
                attention layer and one along-ray attention layer.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the skip tensor `z`.
                NOTE(review): currently unused. The skip projections are built
                with `input_dim` input channels, so `z` must have `input_dim`
                channels.
            hidden_dim: The number of channels produced by the input linear layer.
                Also the feed-forward width of every transformer layer.
            input_skips: The list of stage indices at which a linear projection
                of the skip tensor `z` is added to the stage input.
            dim_down_factor: Stage `i` maps
                `round(hidden_dim / dim_down_factor**i)` channels to
                `round(hidden_dim / dim_down_factor**(i + 1))` channels.
        """
        super().__init__()
        self.first = torch.nn.Linear(input_dim, hidden_dim)
        _xavier_init(self.first)
        self.skip_linear = torch.nn.ModuleList()

        layers_pool, layers_ray = [], []
        # With n_layers == 0 the output layer consumes the output of `first`
        # directly, so it must accept `hidden_dim` channels.
        dimout = hidden_dim
        for layeri in range(n_layers):
            dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
            dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
            logger.info(f"Tr: {dimin} -> {dimout}")
            # The pooled-feature layer changes the width (dimin -> dimout); the
            # along-ray layer then keeps it (dimout -> dimout).
            layers_pool.append(
                TransformerEncoderLayer(
                    d_model=dimin,
                    nhead=4,
                    dim_feedforward=hidden_dim,
                    dropout=0.0,
                    d_model_out=dimout,
                )
            )
            layers_ray.append(
                TransformerEncoderLayer(
                    d_model=dimout,
                    nhead=4,
                    dim_feedforward=hidden_dim,
                    dropout=0.0,
                    d_model_out=dimout,
                )
            )
            if layeri in input_skips:
                self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

        self.last = torch.nn.Linear(dimout, output_dim)
        _xavier_init(self.last)
        self.layers_pool = torch.nn.ModuleList(layers_pool)
        self.layers_ray = torch.nn.ModuleList(layers_ray)
        self._input_skips = set(input_skips)

    def forward(
        self,
        x: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: The input tensor of shape
                `(minibatch, n_pooled_feats, n_rays, n_ray_pts, input_dim)`.
            z: The input skip tensor with the same leading dims as `x` and
                `input_dim` channels. It is linearly projected and added to the
                input of the stages whose indices are specified by `input_skips`.
                Defaults to `x` when None.
        Returns:
            y: The output tensor of shape `(minibatch, n_rays, n_ray_pts, output_dim)`.
                The `n_pooled_feats` dimension is reduced by a softmax-weighted sum.
        """
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x

        y = self.first(x)
        B, n_pool, n_rays, n_pts, dim = y.shape

        # y_p: (n_pool, n_pts, B, n_rays, dim)
        y_p = y.permute(1, 3, 0, 2, 4)

        skipi = 0
        dimh = dim
        for li, (layer_pool, layer_ray) in enumerate(
            zip(self.layers_pool, self.layers_ray)
        ):
            # Attention across pooled features: (n_pool, n_pts*B*n_rays, dimh)
            y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
            if li in self._input_skips:
                z_skip = self.skip_linear[skipi](z)
                y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
                    n_pool, n_pts * B * n_rays, dimh
                )
                skipi += 1
            y_pool_attn, _ = layer_pool(y_pool_attn, src_key_padding_mask=None)
            dimh = y_pool_attn.shape[-1]

            # Attention along each ray: (n_pts, n_pool*B*n_rays, dimh)
            y_ray_attn = (
                y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
                .permute(1, 0, 2, 3)
                .reshape(n_pts, n_pool * B * n_rays, dimh)
            )
            y_ray_attn, _ = layer_ray(
                y_ray_attn,
                src_key_padding_mask=None,
            )

            y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

        # Back to (B, n_pool, n_rays, n_pts, dimh).
        y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

        # The first channel acts as the logits of a softmax over the pooled
        # features, which are then merged by a weighted sum.
        W = torch.softmax(y[..., :1], dim=1)
        y = (y * W).sum(dim=1)
        y = self.last(y)

        return y
class TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
    you need. In Advances in Neural Information Processing Systems,
    pages 6000-6010. Users may modify or implement in a different way during
    application.

    It differs from `torch.nn.TransformerEncoderLayer` in three ways:
        - The output width may differ from the input width (`d_model_out`).
        - `forward` also returns the attention weights.
        - The feed-forward activation is always ReLU.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        d_model_out: the number of output features. Values <= 0 (the default, -1)
            mean `d_model`. The residual connection keeps only the first
            `d_model_out` input channels, so `d_model_out` should not exceed
            `d_model`.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out, attn = encoder_layer(src)
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
    ):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        d_model_out = d_model if d_model_out <= 0 else d_model_out
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model_out)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = torch.nn.functional.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required), laid out as
                (seq_len, batch, d_model) since the attention module uses the
                default `batch_first=False`.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            A tuple `(output, attn_weights)`. `output` has shape
            (seq_len, batch, d_model_out), and `attn_weights` are the weights
            returned by `torch.nn.MultiheadAttention`.
        """
        src2, attn = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        d_out = src2.shape[-1]
        # Truncate the residual to the output width so the sum is well-defined
        # when d_model_out < d_model (src2 already has exactly d_out channels).
        src = src[..., :d_out] + self.dropout2(src2)
        src = self.norm2(src)
        return src, attn
def _xavier_init(linear) -> None: """ Performs the Xavier weight initialization of the linear layer `linear`. """ torch.nn.init.xavier_uniform_(linear.weight.data)