Source code for pytorch3d.implicitron.models.implicit_function.neural_radiance_field

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import Optional, Tuple

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import (
    conical_frustum_to_gaussian,
    ImplicitronRayBundle,
)
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points

from .base import ImplicitFunctionBase

from .decoding_functions import (  # noqa
    _xavier_init,
    MLPWithInputSkips,
    TransformerWithInputSkips,
)
from .utils import create_embeddings_for_implicit_function


logger = logging.getLogger(__name__)


class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
    """
    Common implementation of the NeRF-style implicit functions.

    The class embeds the sampled ray points, encodes them with
    `self.xyz_encoder` (built by the subclass in `_construct_xyz_encoder`),
    and predicts a raw density and a color for every ray point.

    Besides the fields declared here, this class reads
    `self.n_hidden_neurons_xyz`; that field, together with `n_layers_xyz`
    and `append_xyz`, is declared by the concrete subclasses.

    Args:
        n_harmonic_functions_xyz: The number of harmonic functions
            used to form the harmonic embedding of 3D point locations.
        n_harmonic_functions_dir: The number of harmonic functions
            used to form the harmonic embedding of the ray directions.
        n_hidden_neurons_xyz: The number of hidden units in the
            fully connected layers of the MLP that accepts the 3D point
            locations and outputs the occupancy field with the intermediate
            features.
        n_hidden_neurons_dir: The number of hidden units in the
            fully connected layers of the MLP that accepts the intermediate
            features and ray directions and outputs the radiance field
            (per-point colors).
        n_layers_xyz: The number of layers of the MLP that outputs the
            occupancy field.
        append_xyz: The list of indices of the skip layers of the occupancy MLP.
        latent_dim: The number of extra (non-xyz) input channels added to the
            input dimension of the xyz encoder. Has to be > 0 when
            `input_xyz` is False.
        input_xyz: If True, the harmonic embedding of the point locations is
            part of the xyz encoder input.
        xyz_ray_dir_in_camera_coords: If True, the ray directions are
            multiplied by `camera.R` before being embedded, and the flag is
            forwarded as `xyz_in_camera_coords` to
            `create_embeddings_for_implicit_function`. Requires `camera`
            to be passed to `forward`.
        color_dim: The number of output channels of the color layer.
        use_integrated_positional_encoding: If True, use integrated positional
            encoding as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
            If False, use the classical harmonic embedding
            defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
    """

    n_harmonic_functions_xyz: int = 10
    n_harmonic_functions_dir: int = 4
    n_hidden_neurons_dir: int = 128
    latent_dim: int = 0
    input_xyz: bool = True
    xyz_ray_dir_in_camera_coords: bool = False
    color_dim: int = 3
    use_integrated_positional_encoding: bool = False

    def __post_init__(self):
        """
        Build the embedding layers, the xyz encoder and the density/color heads.

        Raises:
            ValueError: If `input_xyz` is False and `latent_dim <= 0`, i.e.
                the xyz encoder would receive no input at all.
        """
        # The harmonic embedding layer converts input 3D coordinates
        # to a representation that is more suitable for
        # processing with a deep neural network.
        self.harmonic_embedding_xyz = HarmonicEmbedding(
            self.n_harmonic_functions_xyz, append_input=True
        )
        self.harmonic_embedding_dir = HarmonicEmbedding(
            self.n_harmonic_functions_dir, append_input=True
        )
        if not self.input_xyz and self.latent_dim <= 0:
            raise ValueError("The latent dimension has to be > 0 if xyz is not input!")

        embedding_dim_dir = self.harmonic_embedding_dir.get_output_dim()

        # The encoder architecture is chosen by the subclass.
        self.xyz_encoder = self._construct_xyz_encoder(
            input_dim=self.get_xyz_embedding_dim()
        )

        self.intermediate_linear = torch.nn.Linear(
            self.n_hidden_neurons_xyz, self.n_hidden_neurons_xyz
        )
        _xavier_init(self.intermediate_linear)

        self.density_layer = torch.nn.Linear(self.n_hidden_neurons_xyz, 1)
        _xavier_init(self.density_layer)

        # Zero the bias of the density layer to avoid
        # a completely transparent initialization.
        self.density_layer.bias.data[:] = 0.0  # fixme: Sometimes this is not enough

        # The color head consumes the pair (intermediate features, embedded
        # ray directions) and squashes its output to [0, 1] with a sigmoid.
        self.color_layer = torch.nn.Sequential(
            LinearWithRepeat(
                self.n_hidden_neurons_xyz + embedding_dim_dir, self.n_hidden_neurons_dir
            ),
            torch.nn.ReLU(True),
            torch.nn.Linear(self.n_hidden_neurons_dir, self.color_dim),
            torch.nn.Sigmoid(),
        )

    def get_xyz_embedding_dim(self):
        """
        Returns the input dimension of the xyz encoder: the output dimension
        of the xyz harmonic embedding (counted only if `input_xyz` is True)
        plus `latent_dim`.
        """
        return (
            self.harmonic_embedding_xyz.get_output_dim() * int(self.input_xyz)
            + self.latent_dim
        )

    def _construct_xyz_encoder(self, input_dim: int):
        """
        Creates the module stored as `self.xyz_encoder`.
        Has to be implemented by subclasses.

        Args:
            input_dim: The dimension of the embedding fed to the encoder.
        """
        raise NotImplementedError()

    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
        """
        This function takes per-point `features` predicted by `self.xyz_encoder`
        and evaluates the color model in order to attach to each
        point a vector of its color (`color_dim` channels).
        """
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
        rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)

        # The first layer of `color_layer` takes both inputs as a tuple.
        return self.color_layer((self.intermediate_linear(features), rays_embedding))

    @staticmethod
    def allows_multiple_passes() -> bool:
        """
        Returns True as this implicit function allows
        multiple passes. Overridden from ImplicitFunctionBase.
        """
        return True

    def forward(
        self,
        *,
        ray_bundle: ImplicitronRayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a color vector
        and a 1D scalar representing the point's
        color and raw density respectively.

        Args:
            ray_bundle: An ImplicitronRayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
                bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)`
                    containing the bins at which the rays are sampled. In this case
                    lengths is equal to the midpoints of bins.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.
            camera: An optional camera object. Required when
                `xyz_ray_dir_in_camera_coords` is True, in which case the ray
                directions are multiplied by `camera.R`.
            global_code: An optional value forwarded unchanged to
                `create_embeddings_for_implicit_function`.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                holding the raw (un-activated) density of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, color_dim)`
                denoting the color of each ray point, in [0, 1].
            aux: An empty dictionary.

        Raises:
            ValueError: If `use_integrated_positional_encoding` is True and
                `ray_bundle.bins` is None.
            ValueError: If `xyz_ray_dir_in_camera_coords` is True and
                `camera` is None.
        """
        if self.use_integrated_positional_encoding and ray_bundle.bins is None:
            # Note the trailing space: the two literals are concatenated.
            raise ValueError(
                "When use_integrated_positional_encoding is True, ray_bundle.bins must be set. "
                "Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
            )

        # With integrated positional encoding every sample is described by
        # a mean and a diagonal covariance; otherwise only by a point.
        rays_points_world, diag_cov = (
            conical_frustum_to_gaussian(ray_bundle)
            if self.use_integrated_positional_encoding
            else (ray_bundle_to_ray_points(ray_bundle), None)  # pyre-ignore
        )
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

        embeds = create_embeddings_for_implicit_function(
            xyz_world=rays_points_world,
            xyz_embedding_function=(
                self.harmonic_embedding_xyz if self.input_xyz else None
            ),
            global_code=global_code,
            fun_viewpool=fun_viewpool,
            xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
            camera=camera,
            diag_cov=diag_cov,
        )

        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.get_xyz_embedding_dim()]
        features = self.xyz_encoder(embeds)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
        # NNs operate on the flattened rays; reshaping to the correct spatial size
        # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
        features = features.reshape(*rays_points_world.shape[:-1], -1)

        raw_densities = self.density_layer(features)
        # raw_densities.shape = [minibatch x ... x 1]
        # These are the unbounded outputs of a linear layer; no activation
        # is applied to them here.

        if self.xyz_ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")

            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions

        rays_colors = self._get_colors(features, directions)
        # rays_colors.shape = [minibatch x ... x color_dim] in [0-1]

        return raw_densities, rays_colors, {}
@registry.register
class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
    """
    Neural radiance field whose xyz encoder is an `MLPWithInputSkips`.

    The fields below only set defaults; everything else comes from
    `NeuralRadianceFieldBase`. `transformer_dim_down_factor` is declared
    but is not read when the MLP encoder is constructed.
    """

    transformer_dim_down_factor: float = 1.0
    n_hidden_neurons_xyz: int = 256
    n_layers_xyz: int = 8
    append_xyz: Tuple[int, ...] = (5,)

    def _construct_xyz_encoder(self, input_dim: int):
        """
        Build the MLP that encodes the embedded point locations.

        Args:
            input_dim: The dimension of the embedding fed to the MLP; it is
                passed both as the 2nd and the 4th positional argument.

        Returns:
            A new `MLPWithInputSkips` instance.
        """
        # MLPWithInputSkips has to be processed by the config system
        # before it can be instantiated.
        expand_args_fields(MLPWithInputSkips)

        hidden_width = self.n_hidden_neurons_xyz
        encoder = MLPWithInputSkips(
            self.n_layers_xyz,
            input_dim,
            hidden_width,
            input_dim,
            hidden_width,
            input_skips=self.append_xyz,
        )
        return encoder
@registry.register
class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
    """
    Neural radiance field whose xyz encoder is a `TransformerWithInputSkips`.

    The fields below only set defaults; everything else comes from
    `NeuralRadianceFieldBase`. `transformer_dim_down_factor` is forwarded
    to the transformer as `dim_down_factor`.
    """

    transformer_dim_down_factor: float = 2.0
    n_hidden_neurons_xyz: int = 80
    n_layers_xyz: int = 2
    append_xyz: Tuple[int, ...] = (1,)

    def _construct_xyz_encoder(self, input_dim: int):
        """
        Build the transformer that encodes the embedded point locations.

        Args:
            input_dim: The dimension of the embedding fed to the transformer;
                it is passed both as the 2nd and the 4th positional argument.

        Returns:
            A new `TransformerWithInputSkips` instance.
        """
        hidden_width = self.n_hidden_neurons_xyz
        encoder = TransformerWithInputSkips(
            self.n_layers_xyz,
            input_dim,
            hidden_width,
            input_dim,
            hidden_width,
            input_skips=self.append_xyz,
            dim_down_factor=self.transformer_dim_down_factor,
        )
        return encoder

    @staticmethod
    def requires_pooling_without_aggregation() -> bool:
        """
        Returns True as this implicit function needs
        pooling without aggregation. Overridden from ImplicitFunctionBase.
        """
        return True