# Source code for pytorch3d.implicitron.models.renderer.ray_point_refiner

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

from pytorch3d.renderer.implicit.sample_pdf import sample_pdf


@expand_args_fields
class RayPointRefiner(Configurable, torch.nn.Module):
    """
    Importance sampler that refines the depth samples of a ray bundle.

    The module receives a `RayBundle`-like object together with a
    `ray_weights` tensor describing, for every ray, how likely each region
    along the ray is to be sampled. It is the raysampler used for the fine
    rendering pass of NeRF: its forward pass consumes the bundle produced by
    the raysampling of the coarse pass, so no cameras are required.

    Args:
        n_pts_per_ray: How many points to draw along each ray.
        random_sampling: When `False`, the equispaced percentiles of the
            distribution given by the weights are returned; when `True`,
            points are randomly sampled from that distribution.
        add_input_samples: When set, the newly drawn values are concatenated
            with the input samples before being returned.
        blurpool_weights: Apply the blurpool of [1] to the input weights.
        sample_pdf_eps: Constant that prevents a division by zero when empty
            bins are present.

    References:
        [1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
            for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
    """

    # pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
    n_pts_per_ray: int
    # pyre-fixme[13]: Attribute `random_sampling` is never initialized.
    random_sampling: bool
    add_input_samples: bool = True
    blurpool_weights: bool = False
    sample_pdf_eps: float = 1e-5

    def forward(
        self,
        input_ray_bundle: ImplicitronRayBundle,
        ray_weights: torch.Tensor,
        blurpool_weights: bool = False,
        sample_pdf_padding: float = 1e-5,
        **kwargs,
    ) -> ImplicitronRayBundle:
        """
        Draw new depth samples along every ray of `input_ray_bundle`.

        Args:
            input_ray_bundle: An `ImplicitronRayBundle` holding the source
                rays. If its `bins` attribute is `None` the samples are
                derived from (and written back to) `lengths`; otherwise
                `bins` is used.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution the ray
                points are sampled from.
            blurpool_weights: Accepted for interface compatibility. The body
                does not read it; blurpooling is controlled by the
                `blurpool_weights` field of the module.
            sample_pdf_padding: Accepted for interface compatibility. The
                body does not read it; the `sample_pdf_eps` field of the
                module is what gets passed to `sample_pdf`.
            **kwargs: Ignored.

        Returns:
            ray_bundle: A shallow copy of `input_ray_bundle` whose `lengths`
                (or `bins`, when the input defines them) are replaced by the
                refined values, sorted along the last dimension. With
                `add_input_samples` the input values are kept next to the
                new samples. In `lengths` mode `n_pts_per_ray` samples are
                drawn per ray, in `bins` mode `n_pts_per_ray + 1`.
        """
        has_bins = input_ray_bundle.bins is not None

        # The sampling itself is carried out with autograd disabled.
        with torch.no_grad():
            pdf_weights = ray_weights
            if self.blurpool_weights:
                pdf_weights = apply_blurpool_on_weights(pdf_weights)
            pdf_weights = pdf_weights.view(-1, pdf_weights.shape[-1])

            if has_bins:
                source_depths: torch.Tensor = input_ray_bundle.bins
                bin_edges = source_depths
                # One extra sample is requested when bin edges are refined.
                n_samples = self.n_pts_per_ray + 1
            else:
                source_depths = input_ray_bundle.lengths
                # Midpoints of consecutive depths act as the bin edges, so
                # the weights of the first and the last point are dropped.
                bin_edges = torch.lerp(
                    source_depths[..., 1:], source_depths[..., :-1], 0.5
                )
                pdf_weights = pdf_weights[..., 1:-1]
                n_samples = self.n_pts_per_ray

            flat_samples = sample_pdf(
                bin_edges.view(-1, bin_edges.shape[-1]),
                pdf_weights,
                n_samples,
                det=not self.random_sampling,
                eps=self.sample_pdf_eps,
            )
            new_depths = flat_samples.view(*source_depths.shape[:-1], n_samples)

        if self.add_input_samples:
            merged = torch.cat((source_depths, new_depths), dim=-1)
        else:
            merged = new_depths

        # Re-establish the ordering by depth.
        refined_depths = torch.sort(merged, dim=-1).values

        refined_bundle = copy.copy(input_ray_bundle)
        if has_bins:
            refined_bundle.bins = refined_depths
        else:
            refined_bundle.lengths = refined_depths
        return refined_bundle
def apply_blurpool_on_weights(weights: torch.Tensor) -> torch.Tensor:
    """
    Filter weights with a 2-tap max filter followed by a 2-tap blur filter,
    which produces a wide and smooth upper envelope on the weights.

    Both filters run along the last dimension only; all leading dimensions
    are treated as independent rows.

    Args:
        weights: Tensor of shape `(..., dim)`. A plain 1-D tensor of shape
            `(dim,)` is accepted as well.

    Returns:
        blurred_weights: Tensor of shape `(..., dim)`
    """
    # Replicate the two border values so that the lengths work out:
    # dim + 2 (padded) -> dim + 1 (after max) -> dim (after blur).
    weights_pad = torch.cat(
        (weights[..., :1], weights, weights[..., -1:]),
        dim=-1,
    )
    # Collapse every leading dimension into one. `reshape` is used instead of
    # `flatten(end_dim=-2)` because the latter raises for 1-D inputs.
    rows = weights_pad.reshape(-1, weights_pad.shape[-1])
    weights_max = torch.nn.functional.max_pool1d(rows, 2, stride=1)
    # 2-tap blur: the average of each pair of neighbouring maxima.
    blurred = torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5)
    return blurred.reshape_as(weights)