Source code for pytorch3d.renderer.implicit.harmonic_embedding

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Optional

import torch


class HarmonicEmbedding(torch.nn.Module):
    """
    Harmonic (sin/cos) embedding of the last dimension of a tensor.

    The layer supports the classical positional encoding described in
    `NeRF <https://arxiv.org/abs/2003.08934>`_ and the integrated positional
    encoding of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. The latter is
    selected at call time by passing the extra argument `diag_cov` to
    `forward`.

    If `diag_cov is None`, each feature (i.e. vector along the last
    dimension) of `x` is converted into a series of harmonic features
    `embedding`, where for each i in range(dim) the following are present in
    embedding[...]::

        [
            sin(f_1 * x[..., i]),
            sin(f_2 * x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(f_1 * x[..., i]),
            cos(f_2 * x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i],              # only present if append_input is True.
        ]

    where N equals `n_harmonic_functions` and f_k is a scalar denoting the
    k-th frequency of the harmonic embedding.

    If `diag_cov is not None`, `x` and `diag_cov` are interpreted as the means
    and the diagonal covariances of Gaussians, and every harmonic feature is
    attenuated by the variance seen at its frequency::

        [
            sin(f_1 * x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
            sin(f_2 * x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
            ...
            sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
            cos(f_1 * x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
            cos(f_2 * x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
            ...
            cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
            x[..., i],              # only present if append_input is True.
        ]

    Memory layout of the last output dimension: all sine features come first
    (grouped by input coordinate i, and by frequency within each coordinate),
    then all cosine features in the same order, then, if `append_input` is
    True, the unmodified input `x`.

    If `logspace==True`, the frequencies `[f_1, ..., f_N]` are powers of 2:
        `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

    If `logspace==False`, the frequencies are linearly spaced between
    `1.0` and `2**(n_harmonic_functions-1)`:
        `f_1, ..., f_N = torch.linspace(
            1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
        )`

    In both cases every frequency is additionally multiplied by the base
    frequency `omega_0`.
    """

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Args:
            n_harmonic_functions: int, number of harmonic frequencies N.
            omega_0: float, base frequency that scales every frequency.
            logspace: bool, whether to space the frequencies in logspace
                (powers of 2) or in linear space.
            append_input: bool, whether to concat the original input to the
                harmonic embedding. If True the output is of the form
                (embed.sin(), embed.cos(), x).
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Non-persistent buffers: they follow the module across devices and
        # dtypes but are not written to the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets [0, pi/2] used to obtain sin and cos from one sin call.
        self.register_buffer(
            "_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our Gaussians,
                joined with x as means of the Gaussians.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
            [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1]
        #   => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2)
        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()

        if diag_cov is not None:
            # Variance of each coordinate as seen at each frequency:
            # [..., dim, n_harmonic_functions]
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # Broadcast over the sin/cos axis:
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        # Spell out the flattened width instead of passing -1: reshape cannot
        # infer -1 for a tensor with zero elements, which would make an empty
        # batch of points (or n_harmonic_functions == 0) raise.
        flat_dim = 2 * x.shape[-1] * self._frequencies.shape[0]
        embed = embed.reshape(*x.shape[:-1], flat_dim)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this instance's number of
        frequencies and `append_input` setting.

        The default for input_dims is 3 for 3D applications which use
        harmonic embedding for positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )