# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.ops import packed_to_padded
from pytorch3d.renderer.implicit.utils import ray_bundle_variables_to_ray_points
class EvaluationMode(Enum):
    """
    Selects between training and evaluation settings.

    For instance, `BaseRenderer.forward` receives one of these members to
    decide which settings are used for rendering.
    """

    # Each member's value is its lowercase name.
    TRAINING = "training"
    EVALUATION = "evaluation"
class RenderSamplingMode(Enum):
    """
    Enumerates the ray-sampling strategies used for rendering.

    NOTE(review): presumably MASK_SAMPLE draws rays guided by a mask and
    FULL_GRID covers the whole pixel grid -- confirm against the ray sampler.
    """

    # Each member's value is its lowercase name.
    MASK_SAMPLE = "mask_sample"
    FULL_GRID = "full_grid"
class ImplicitronRayBundle:
    """
    Parametrizes points along projection rays by storing ray `origins`,
    `directions` vectors and `lengths` at which the ray-points are sampled.
    Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
    Note that `directions` don't have to be normalized; they define unit vectors
    in the respective 1D coordinate systems; see documentation for
    :func:`ray_bundle_to_ray_points` for the conversion formula.

    A ray bundle may represent rays from multiple cameras. In that case, cameras
    are stored in the packed form (i.e. rays from the same camera are stored in
    consecutive elements). The following indices will be set:
        camera_ids: A tensor of shape (N, ) which indicates which camera
            was used to sample the rays. `N` is the number of different
            sampled cameras.
        camera_counts: A tensor of shape (N, ) which indicates how many times
            the corresponding camera in `camera_ids` was sampled.
            `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.

    Attributes:
        origins: A tensor of shape `(..., 3)` denoting the
            origins of the sampling rays in world coords.
        directions: A tensor of shape `(..., 3)` containing the direction
            vectors of sampling rays in world coords. They don't have to be normalized;
            they define unit vectors in the respective 1D coordinate systems; see
            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
        lengths: A tensor of shape `(..., num_points_per_ray)`
            containing the lengths at which the rays are sampled.
            If `bins` is set, this is computed from `bins` (see below) and any
            `lengths` passed to the constructor is ignored.
        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels.
        camera_ids: An optional tensor of shape (N, ) which indicates which camera
            was used to sample the rays. `N` is the number of unique sampled cameras.
        camera_counts: An optional tensor of shape (N, ) which indicates how many
            times the corresponding camera in `camera_ids` was sampled.
            `sum(camera_counts) == total_number_of_rays`.
        bins: An optional tensor of shape `(..., num_points_per_ray + 1)`
            containing the bins at which the rays are sampled. When set,
            `lengths` is the midpoints of consecutive bins, of shape
            `(..., num_points_per_ray)`, and cannot be assigned.
        pixel_radii_2d: An optional tensor of shape `(..., 1)`
            base radii of the conical frustums.

    Raises:
        ValueError: If neither bins nor lengths is provided.
        ValueError: If bins is provided and its last dim is less than or equal to 1.
    """

    def __init__(
        self,
        origins: torch.Tensor,
        directions: torch.Tensor,
        lengths: Optional[torch.Tensor],
        xys: torch.Tensor,
        camera_ids: Optional[torch.LongTensor] = None,
        camera_counts: Optional[torch.LongTensor] = None,
        bins: Optional[torch.Tensor] = None,
        pixel_radii_2d: Optional[torch.Tensor] = None,
    ):
        if bins is not None and bins.shape[-1] <= 1:
            raise ValueError(
                "The last dim of bins must be at least superior or equal to 2."
            )
        if bins is None and lengths is None:
            raise ValueError(
                "Please set either bins or lengths to initialize an ImplicitronRayBundle."
            )

        self.origins = origins
        self.directions = directions
        # With `bins` set, `lengths` is always derived from them (see the
        # `lengths` property), so an explicitly passed `lengths` is dropped.
        self._lengths = lengths if bins is None else None
        self.xys = xys
        self.bins = bins
        self.pixel_radii_2d = pixel_radii_2d
        self.camera_ids = camera_ids
        self.camera_counts = camera_counts

    @property
    def lengths(self) -> torch.Tensor:
        """
        Lengths at which the rays are sampled, shape `(..., num_points_per_ray)`.

        If `bins` is set, returns the midpoints of consecutive bins;
        otherwise returns the stored lengths.
        """
        if self.bins is not None:
            # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
            # pyre-ignore
            return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
        return self._lengths

    @lengths.setter
    def lengths(self, value):
        """
        Sets the stored lengths.

        Raises:
            ValueError: if `bins` is set, since lengths are then derived from bins.
        """
        if self.bins is not None:
            raise ValueError(
                "If the bins attribute is not None you cannot set the lengths attribute."
            )
        self._lengths = value

    def float_(self) -> None:
        """Moves the tensors to float dtype in place
        (helpful for mixed-precision tensors).

        `camera_ids` and `camera_counts` are left untouched (index tensors).
        """
        self.origins = self.origins.float()
        self.directions = self.directions.float()
        self._lengths = self._lengths.float() if self._lengths is not None else None
        self.xys = self.xys.float()
        self.bins = self.bins.float() if self.bins is not None else None
        self.pixel_radii_2d = (
            self.pixel_radii_2d.float() if self.pixel_radii_2d is not None else None
        )

    def is_packed(self) -> bool:
        """
        Returns whether the ImplicitronRayBundle carries data in packed state,
        i.e. whether both `camera_ids` and `camera_counts` are set.
        """
        return self.camera_ids is not None and self.camera_counts is not None

    def get_padded_xys(self) -> Tuple[torch.Tensor, torch.LongTensor, int]:
        """
        For a packed ray bundle, returns padded rays. Assumes the input bundle is packed
        (i.e. `camera_ids` and `camera_counts` are set).

        Returns:
            - xys: Tensor of shape (N, max_size, ...) containing the padded
                representation of the pixel coordinates, where `max_size` is
                the max of `camera_counts`. The values for camera id `i`
                will be copied to `xys[i, :]`, with zeros padding out the extra inputs.
            - first_idxs: exclusive cumulative sum of `camera_counts`, i.e. the
                index of the first ray of each camera in the packed representation.
            - num_inputs: the total number of rays in the packed representation,
                i.e. `sum(camera_counts)` (NOT the number of cameras).

        Raises:
            ValueError: if the bundle is not packed.
        """
        if not self.is_packed():
            raise ValueError("get_padded_xys can be called only on a packed bundle")

        camera_counts = self.camera_counts
        assert camera_counts is not None

        cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
        # Shift the inclusive cumsum right by one so that camera i starts at
        # first_idxs[i] and the first camera starts at 0.
        first_idxs = torch.cat(
            (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
        )
        num_inputs = camera_counts.sum().item()
        max_size = torch.max(camera_counts).item()
        xys = packed_to_padded(self.xys, first_idxs, max_size)
        # pyre-ignore [7] pytorch typeshed inaccuracy
        return xys, first_idxs, num_inputs
@dataclass
class RendererOutput:
    """
    A structure for storing the output of a renderer.

    Args:
        features: rendered features (usually RGB colors), (B, ..., C) tensor.
        depths: rendered ray-termination depth map, in NDC coordinates, (B, ..., 1) tensor.
        masks: rendered object mask, values in [0, 1], (B, ..., 1) tensor.
        prev_stage: for multi-pass renderers (e.g. in NeRF),
            a reference to the output of the previous stage.
        normals: surface normals, for renderers that estimate them; (B, ..., 3) tensor.
        points: ray-termination points in the world coordinates, (B, ..., 3) tensor.
        weights: optional weights tensor produced by the renderer.
            NOTE(review): presumably per-ray-point rendering weights; the shape
            is set by each renderer -- confirm against the implementations.
        aux: dict for implementation-specific renderer outputs
            (a fresh empty dict per instance by default).
    """

    features: torch.Tensor
    depths: torch.Tensor
    masks: torch.Tensor
    prev_stage: Optional[RendererOutput] = None
    normals: Optional[torch.Tensor] = None
    points: Optional[torch.Tensor] = None  # TODO: redundant with depths
    weights: Optional[torch.Tensor] = None
    aux: Dict[str, Any] = field(default_factory=dict)
class ImplicitFunctionWrapper(torch.nn.Module):
    """
    Wraps an implicit-function module so that keyword arguments can be bound
    once and then forwarded automatically on every subsequent call.
    """

    def __init__(self, fn: torch.nn.Module):
        super().__init__()
        self._fn = fn
        # Keyword arguments injected into every call of the wrapped module.
        self.bound_args = {}

    def bind_args(self, **bound_args):
        """
        Replace the bound keyword arguments with `bound_args`, then notify
        the wrapped module through its `on_bind_args` hook.
        """
        self.bound_args = bound_args
        self._fn.on_bind_args()

    def unbind_args(self):
        """Drop every previously bound keyword argument."""
        self.bound_args = {}

    def forward(self, *args, **kwargs):
        """
        Invoke the wrapped module with `args`, the call-site `kwargs` and the
        bound keyword arguments. On a name clash the bound value wins.
        """
        call_kwargs = dict(kwargs)
        call_kwargs.update(self.bound_args)
        return self._fn(*args, **call_kwargs)
class BaseRenderer(ABC, ReplaceableBase):
    """
    Abstract base shared by every Renderer implementation.

    Subclasses must implement `forward`; they may override
    `requires_object_mask` when rendering needs an object mask.
    """

    def requires_object_mask(self) -> bool:
        """
        Report whether `forward` needs the object_mask.

        The base implementation always answers False.
        """
        return False

    @abstractmethod
    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        """
        Render the rays of `ray_bundle`. Every concrete Renderer implements
        this method and returns a RendererOutput.

        Args:
            ray_bundle: An ImplicitronRayBundle holding:
                origins: A tensor of shape (minibatch, ..., 3) with the
                    origins of the rendering rays.
                directions: A tensor of shape (minibatch, ..., 3) with the
                    direction vectors of the rendering rays.
                lengths: A tensor of shape (minibatch, ..., num_points_per_ray)
                    with the lengths at which the ray points are sampled, so a
                    point on a ray is `origins + lengths * directions`.
                xys: A tensor of shape (minibatch, ..., 2) with the xy
                    location of each ray's pixel in the NDC screen space.
                camera_ids: A tensor of shape (N, ) telling which camera was
                    used to sample the rays; `N` is the number of different
                    sampled cameras.
                camera_counts: A tensor of shape (N, ) telling how many times
                    the corresponding camera in `camera_ids` was sampled;
                    `sum(camera_counts) == minibatch`.
            implicit_functions: List of ImplicitFunctionWrappers defining the
                implicit function methods to use. Most Renderers accept only a
                single implicit function; currently only the
                MultiPassEmissionAbsorptionRenderer accepts multiple values.
            evaluation_mode: EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION; determines the settings used for
                rendering.
            **kwargs: Additional implementation-specific keyword arguments.
                For example, SignedDistanceFunctionRenderer requires an
                object_mask, which is passed through kwargs.

        Returns:
            An instance of RendererOutput.
        """
def compute_3d_diagonal_covariance_gaussian(
    rays_directions: torch.Tensor,
    rays_dir_variance: torch.Tensor,
    radii_variance: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Map the frustum-frame variances (along the ray and across its radius) of
    each Gaussian to the diagonal of its 3D world-space covariance.

    Implements the diagonal of equation 16 of
    `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_:
        diag(Sigma) = rays_dir_variance * d**2
                      + radii_variance * (1 - d**2 / |d|**2)

    Args:
        rays_directions: A tensor of shape `(..., 3)` with the ray directions `d`.
        rays_dir_variance: A tensor of shape `(..., num_intervals)`, the
            variance of each conical frustum along the ray direction.
        radii_variance: A tensor of shape `(..., num_intervals)`, the
            variance of each conical frustum with respect to its radius.
        eps: lower bound on `|d|**2`, preventing division by zero.

    Returns:
        A tensor of shape `(..., num_intervals, 3)` holding the diagonal of
        each covariance matrix.
    """
    squared_dirs = rays_directions.pow(2)
    norm_sq = squared_dirs.sum(dim=-1, keepdim=True).clamp(min=eps)
    # Diagonal of (I - d d^T / |d|^2): the part orthogonal to the ray.
    orthogonal_diag = 1 - squared_dirs / norm_sq
    # Add an interval axis so per-interval variances (..., num_intervals, 1)
    # broadcast against per-ray 3-vectors (..., 1, 3).
    along_ray = rays_dir_variance.unsqueeze(-1) * squared_dirs.unsqueeze(-2)
    across_ray = radii_variance.unsqueeze(-1) * orthogonal_diag.unsqueeze(-2)
    return along_ray + across_ray
def approximate_conical_frustum_as_gaussians(
    bins: torch.Tensor, radii: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Approximates each conical frustum (one per interval between consecutive
    `bins`) by a Gaussian characterized by three values:
        - rays_dir_mean: mean along the ray direction
            (t in the parametric representation of a cone).
        - rays_dir_variance: variance of the conical frustum along the ray direction.
        - radii_variance: variance of the conical frustum with respect to its radius.

    The computation is stable and follows equation 7 of
    `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_; the appendix of the paper
    explains how the mean and variances are derived. With
    `t_mu = (t0 + t1) / 2` the interval midpoint, `t_delta = (t1 - t0) / 2`
    its half-width and `den = 3 * t_mu**2 + t_delta**2`:

        rays_dir_mean     = t_mu + 2 * t_mu * t_delta**2 / den
        rays_dir_variance = t_delta**2 / 3
                            - (4 / 15) * t_delta**4 * (12 * t_mu**2 - t_delta**2) / den**2
        radii_variance    = radii**2 * (t_mu**2 / 4 + (5 / 12) * t_delta**2
                                        - (4 / 15) * t_delta**4 / den)

    NOTE: an interval with t0 == t1 == 0 gives den == 0 and hence NaN outputs.

    Args:
        bins: A tensor of shape `(..., num_points_per_ray + 1)`
            containing the bins at which the rays are sampled.
            `bins[..., i]` and `bins[..., i + 1]` are respectively the left (t0)
            and right (t1) coordinates of the i-th interval.
        radii: A tensor of shape `(..., 1)` with the base radii of the
            conical frustums; broadcast over the intervals.

    Returns:
        rays_dir_mean: A tensor of shape `(..., num_points_per_ray)` representing
            the mean along the ray direction
            (t in the parametric representation of the cone).
        rays_dir_variance: A tensor of shape `(..., num_points_per_ray)` representing
            the variance of the conical frustum along the ray
            (t in the parametric representation of the cone).
        radii_variance: A tensor of shape `(..., num_points_per_ray)` representing
            the variance of the conical frustum with respect to its radius.
    """
    # Interval midpoints. The argument order matches ImplicitronRayBundle.lengths
    # so both compute the midpoint with the same floating-point operations.
    t_mu = torch.lerp(bins[..., :-1], bins[..., 1:], 0.5)
    # Interval half-widths.
    t_delta = torch.diff(bins, dim=-1) / 2

    t_mu_pow2 = torch.pow(t_mu, 2)
    t_delta_pow2 = torch.pow(t_delta, 2)
    t_delta_pow4 = torch.pow(t_delta, 4)

    den = 3 * t_mu_pow2 + t_delta_pow2

    # Mean along the ray direction.
    rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den

    # Variance of the conical frustum along the ray direction.
    rays_dir_variance = t_delta_pow2 / 3 - (4 / 15) * (
        t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2) / torch.pow(den, 2)
    )

    # Variance of the conical frustum with respect to its radius.
    radii_variance = torch.pow(radii, 2) * (
        t_mu_pow2 / 4 + (5 / 12) * t_delta_pow2 - 4 / 15 * (t_delta_pow4) / den
    )
    return rays_dir_mean, rays_dir_variance, radii_variance
def conical_frustum_to_gaussian(
    ray_bundle: ImplicitronRayBundle,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Approximate the conical frustums along a ray bundle as Gaussians.

    Args:
        ray_bundle: An `ImplicitronRayBundle` with (at least) the fields:
            origins: A tensor of shape `(..., 3)`
            directions: A tensor of shape `(..., 3)`
            bins: A tensor of shape `(..., num_points_per_ray + 1)`
                containing the bins at which the rays are sampled.
            pixel_radii_2d: A tensor of shape `(..., 1)`
                base radii of the conical frustums.

    Returns:
        means: A tensor of shape `(..., num_points_per_ray, 3)`
            representing the means of the Gaussians
            approximating the conical frustums (one per interval of `bins`).
        diag_covariances: A tensor of shape `(..., num_points_per_ray, 3)`
            representing the diagonal covariance matrices of the Gaussians.

    Raises:
        ValueError: if `ray_bundle.pixel_radii_2d` or `ray_bundle.bins` is None.
    """
    if ray_bundle.pixel_radii_2d is None or ray_bundle.bins is None:
        raise ValueError(
            "RayBundle pixel_radii_2d or bins have not been provided."
            " Look at pytorch3d.implicitron.models.renderer.ray_sampler::"
            "AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
            " `cast_ray_bundle_as_cone` to True?"
        )

    (
        rays_dir_mean,
        rays_dir_variance,
        radii_variance,
    ) = approximate_conical_frustum_as_gaussians(
        ray_bundle.bins,
        ray_bundle.pixel_radii_2d,
    )
    # The Gaussian means lie on the rays, at the per-interval mean lengths.
    means = ray_bundle_variables_to_ray_points(
        ray_bundle.origins, ray_bundle.directions, rays_dir_mean
    )
    diag_covariances = compute_3d_diagonal_covariance_gaussian(
        ray_bundle.directions, rays_dir_variance, radii_variance
    )
    return means, diag_covariances