# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer import NDCMultinomialRaysampler
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
[docs]
class RaySamplerBase(ReplaceableBase):
    """
    Abstract interface that every Implicitron ray sampler implements.

    Concrete samplers override `forward`; this base implementation only
    raises `NotImplementedError`.
    """

    def forward(
        self,
        cameras: CamerasBase,
        evaluation_mode: EvaluationMode,
        mask: Optional[torch.Tensor] = None,
    ) -> ImplicitronRayBundle:
        """
        Emit a bundle of rendering rays for a batch of cameras.

        Args:
            cameras: A batch of `batch_size` cameras from which the rays
                are emitted.
            evaluation_mode: Either `EvaluationMode.TRAINING` or
                `EvaluationMode.EVALUATION`; selects which sampling mode
                is used.
            mask: Only relevant for the `RenderSamplingMode.MASK_SAMPLE`
                sampling mode. A non-negative mask of shape
                `(batch_size, image_height, image_width)` whose per-pixel
                values are proportional to the probability of sampling
                the ray of the corresponding pixel.

        Returns:
            ray_bundle: An `ImplicitronRayBundle` holding the
                parametrizations of the sampled rendering rays.

        Raises:
            NotImplementedError: Always; subclasses provide the sampling.
        """
        raise NotImplementedError()
[docs]
class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
    """
    Samples a fixed number of points along rays which are in turn sampled for
    each camera in a batch.

    This class utilizes `NDCMultinomialRaysampler` which allows to either
    randomly sample rays from an input foreground saliency mask
    (`RenderSamplingMode.MASK_SAMPLE`), or on a rectangular image grid
    (`RenderSamplingMode.FULL_GRID`). The sampling mode can be set separately
    for training and evaluation by setting `self.sampling_mode_training`
    and `self.sampling_mode_evaluation` accordingly.

    The class allows to adjust the sampling points along rays by overwriting the
    `AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
    the near/far planes (`min_depth`/`max_depth`) that are passed to
    `NDCMultinomialRaysampler` on every `forward` call.

    Settings:
        image_width: The horizontal size of the image grid.
        image_height: The vertical size of the image grid.
        sampling_mode_training: The ray sampling mode for training. This should
            be a str option from the RenderSamplingMode Enum.
        sampling_mode_evaluation: Same as above but for evaluation.
        n_pts_per_ray_training: The number of points sampled along each ray
            during training.
        n_pts_per_ray_evaluation: The number of points sampled along each ray
            during evaluation.
        n_rays_per_image_sampled_from_mask: The amount of rays to be sampled
            from the image grid. Given a batch of image grids, this many is
            sampled from each.
            `n_rays_per_image_sampled_from_mask` and `n_rays_total_training`
            cannot both be defined.
        n_rays_total_training: (optional) How many rays in total to sample from
            the entire batch of provided image grid. The result is as if
            `n_rays_total_training` cameras/image grids were sampled with
            replacement from the cameras / image grids provided and for every
            camera one ray was sampled.
            `n_rays_per_image_sampled_from_mask` and `n_rays_total_training`
            cannot both be defined; to use this setting you have to set
            `n_rays_per_image_sampled_from_mask` to None.
            Used only for EvaluationMode.TRAINING.
        stratified_point_sampling_training: if set, performs stratified random
            sampling along the ray; otherwise takes ray points at deterministic
            offsets.
        stratified_point_sampling_evaluation: Same as above but for evaluation.
        cast_ray_bundle_as_cone: If True, the sampling will generate the bins
            and radii attribute of ImplicitronRayBundle. The `bins` contain the
            z-coordinate (=depth) of each ray in world units and are of shape
            `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation + 1)`
            while `lengths` is equal to the midpoint of the bins:
            (0.5 * (bins[..., 1:] + bins[..., :-1]).
            If False, `bins` is None, `radii` is None and `lengths` contains
            the z-coordinate (=depth) of each ray in world units and are of
            shape
            `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation)`

    Raises:
        ValueError: at construction, if both
            `n_rays_per_image_sampled_from_mask` and `n_rays_total_training`
            are set.
        TypeError: in `forward`, if `cast_ray_bundle_as_cone` is True and the
            underlying raysampler returns a `HeterogeneousRayBundle` (which
            happens when `n_rays_total_training` is not None).
            HeterogeneousRayBundle is not supported for conical frustum
            computation yet.
    """

    image_width: int = 400
    image_height: int = 400
    sampling_mode_training: str = "mask_sample"
    sampling_mode_evaluation: str = "full_grid"
    n_pts_per_ray_training: int = 64
    n_pts_per_ray_evaluation: int = 64
    n_rays_per_image_sampled_from_mask: Optional[int] = 1024
    n_rays_total_training: Optional[int] = None
    # stratified sampling vs taking points at deterministic offsets
    stratified_point_sampling_training: bool = True
    stratified_point_sampling_evaluation: bool = False
    cast_ray_bundle_as_cone: bool = False

    def __post_init__(self):
        """
        Validate the settings and build the training / evaluation raysamplers.
        """
        if (self.n_rays_per_image_sampled_from_mask is not None) and (
            self.n_rays_total_training is not None
        ):
            raise ValueError(
                "Cannot both define n_rays_total_training and "
                "n_rays_per_image_sampled_from_mask."
            )

        self._sampling_mode = {
            EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
            EvaluationMode.EVALUATION: RenderSamplingMode(
                self.sampling_mode_evaluation
            ),
        }

        # In cone mode the sampled depths are used as bin edges, and
        # n_pts_per_ray intervals need n_pts_per_ray + 1 edges.
        n_pts_per_ray_training = (
            self.n_pts_per_ray_training + 1
            if self.cast_ray_bundle_as_cone
            else self.n_pts_per_ray_training
        )
        n_pts_per_ray_evaluation = (
            self.n_pts_per_ray_evaluation + 1
            if self.cast_ray_bundle_as_cone
            else self.n_pts_per_ray_evaluation
        )

        # min_depth / max_depth are placeholders here: the actual bounds are
        # supplied on every call in `forward`.
        self._training_raysampler = NDCMultinomialRaysampler(
            image_width=self.image_width,
            image_height=self.image_height,
            n_pts_per_ray=n_pts_per_ray_training,
            min_depth=0.0,
            max_depth=0.0,
            n_rays_per_image=(
                self.n_rays_per_image_sampled_from_mask
                if self._sampling_mode[EvaluationMode.TRAINING]
                == RenderSamplingMode.MASK_SAMPLE
                else None
            ),
            n_rays_total=self.n_rays_total_training,
            unit_directions=True,
            stratified_sampling=self.stratified_point_sampling_training,
        )

        # `n_rays_total` is intentionally not forwarded: it only applies to
        # training.
        self._evaluation_raysampler = NDCMultinomialRaysampler(
            image_width=self.image_width,
            image_height=self.image_height,
            n_pts_per_ray=n_pts_per_ray_evaluation,
            min_depth=0.0,
            max_depth=0.0,
            n_rays_per_image=(
                self.n_rays_per_image_sampled_from_mask
                if self._sampling_mode[EvaluationMode.EVALUATION]
                == RenderSamplingMode.MASK_SAMPLE
                else None
            ),
            unit_directions=True,
            stratified_sampling=self.stratified_point_sampling_evaluation,
        )

        # Size of one pixel, derived from the extent of the training
        # raysampler's grid; consumed by `compute_radii` in cone mode.
        # NOTE(review): an image size of 1 makes the denominator zero - confirm
        # whether single-pixel grids need to be supported.
        max_y, min_y = self._training_raysampler.max_y, self._training_raysampler.min_y
        max_x, min_x = self._training_raysampler.max_x, self._training_raysampler.min_x
        self.pixel_height: float = (max_y - min_y) / (self.image_height - 1)
        self.pixel_width: float = (max_x - min_x) / (self.image_width - 1)

    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
        """
        Return the `(min_depth, max_depth)` planes between which ray points
        are sampled. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def forward(
        self,
        cameras: CamerasBase,
        evaluation_mode: EvaluationMode,
        mask: Optional[torch.Tensor] = None,
    ) -> ImplicitronRayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are
                emitted.
            evaluation_mode: one of `EvaluationMode.TRAINING` or
                `EvaluationMode.EVALUATION` which determines the sampling mode
                that is used.
            mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
                Defines a non-negative mask where each per-pixel value is
                proportional to the probability of sampling the corresponding
                pixel's ray. Accepted shapes are `(batch_size, H, W)` and
                `(batch_size, C, H, W)`; in the latter case only channel 0 is
                used. The mask is resized to
                `(self.image_height, self.image_width)` with nearest-neighbor
                interpolation before sampling.

        Returns:
            ray_bundle: A `ImplicitronRayBundle` object containing the
                parametrizations of the sampled rendering rays.

        Raises:
            TypeError: if `cast_ray_bundle_as_cone` is set and the raysampler
                produced a `HeterogeneousRayBundle`.
        """
        sample_mask = None
        if (
            # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
            self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
            and mask is not None
        ):
            if mask.dim() == 3:
                # 2D resizing with `interpolate` needs an explicit channel
                # axis; add one so the `(batch_size, H, W)` layout works too.
                mask = mask[:, None]
            sample_mask = torch.nn.functional.interpolate(
                mask,
                size=[self.image_height, self.image_width],
                mode="nearest",
            )[:, 0]

        min_depth, max_depth = self._get_min_max_depth_bounds(cameras)

        raysampler = {
            EvaluationMode.TRAINING: self._training_raysampler,
            EvaluationMode.EVALUATION: self._evaluation_raysampler,
        }[evaluation_mode]

        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        ray_bundle = raysampler(
            cameras=cameras,
            mask=sample_mask,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        if self.cast_ray_bundle_as_cone:
            if isinstance(ray_bundle, HeterogeneousRayBundle):
                # If this error rises it means that raysampler has among
                # its arguments `n_rays_total`. If it is the case
                # then you should update the radii computation and lengths
                # computation to handle padding and unpadding.
                raise TypeError(
                    "Heterogeneous ray bundle is not supported for conical frustum computation yet"
                )

            # pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
            #  `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
            pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
            pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
            # The sampled depths are handed over as bin edges; `lengths` is
            # left unset in this mode.
            return ImplicitronRayBundle(
                directions=ray_bundle.directions,
                origins=ray_bundle.origins,
                lengths=None,
                xys=ray_bundle.xys,
                bins=ray_bundle.lengths,
                pixel_radii_2d=pixel_radii_2d,
            )

        # `camera_counts` / `camera_ids` are forwarded only when the sampled
        # bundle carries them; otherwise None is passed.
        return ImplicitronRayBundle(
            directions=ray_bundle.directions,
            origins=ray_bundle.origins,
            lengths=ray_bundle.lengths,
            xys=ray_bundle.xys,
            camera_counts=getattr(ray_bundle, "camera_counts", None),
            camera_ids=getattr(ray_bundle, "camera_ids", None),
        )
[docs]
@registry.register
class AdaptiveRaySampler(AbstractMaskRaySampler):
    """
    Ray sampler whose near and far planes adapt to the camera position.

    The depth range is derived from the distance between the camera center
    and a predefined scene center. More specifically,
    `min_depth = max(
        (self.scene_center-camera_center).norm() - self.scene_extent, eps
    )` and
    `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.

    This sampling is ideal for object-centric scenes whose contents are
    centered around a known `self.scene_center` and fit into a bounding sphere
    with a radius of `self.scene_extent`.

    Settings:
        scene_center: The xyz coordinates of the center of the scene used
            along with `scene_extent` to compute the min and max depth planes
            for sampling ray-points.
        scene_extent: The radius of the scene bounding sphere centered at
            `scene_center`. Must be strictly positive.
    """

    scene_extent: float = 8.0
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    def __post_init__(self):
        """
        Run the parent initialization, validate `scene_extent` and cache the
        scene center as a tensor.

        Raises:
            ValueError: if `scene_extent` is not strictly positive.
        """
        super().__post_init__()
        if self.scene_extent <= 0.0:
            raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
        self._scene_center = torch.FloatTensor(self.scene_center)

    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
        """
        Returns the adaptively calculated near/far planes as plain floats.
        """
        near, far = camera_utils.get_min_max_depth_bounds(
            cameras, self._scene_center, self.scene_extent
        )
        # Only element 0 of each computed bound is returned.
        return float(near[0]), float(far[0])
[docs]
@registry.register
class NearFarRaySampler(AbstractMaskRaySampler):
    """
    Ray sampler that places points between fixed near and far z-planes.

    Points along each ray are sampled with approximately uniform spacing of
    z-coordinates between the minimum depth `self.min_depth` and the maximum
    depth `self.max_depth`. This sampling is useful for rendering scenes where
    the camera is in a constant distance from the focal point of the scene.

    Settings:
        min_depth: The minimum depth of a ray-point.
        max_depth: The maximum depth of a ray-point.
    """

    min_depth: float = 0.1
    max_depth: float = 8.0

    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
        """
        Returns the configured near/far planes; `cameras` is not consulted.
        """
        near, far = self.min_depth, self.max_depth
        return near, far
[docs]
def compute_radii(
    cameras: CamerasBase,
    xy_grid: torch.Tensor,
    pixel_hw_ndc: Tuple[float, float],
) -> torch.Tensor:
    """
    Compute radii of conical frustums in world coordinates.

    Every grid location and its two neighbors (one pixel away along x and
    along y) are unprojected with `cameras.unproject_points` at depth 1;
    the radius is derived from the distances between the unprojected points.

    Args:
        cameras: cameras object representing a batch of cameras.
        xy_grid: tensor of image xy coords in NDC, of shape
            `(batch_size, ..., 2)`. It does not need to be contiguous.
        pixel_hw_ndc: pixel height and width in NDC, in that order.

    Returns:
        radii: A tensor of shape `(batch_size, ..., 1)` holding the radius of
            the cone for every location of `xy_grid`.
    """
    batch_size = xy_grid.shape[0]
    spatial_size = xy_grid.shape[1:-1]
    n_rays_per_image = spatial_size.numel()

    # `reshape` rather than `view`: callers may pass slices or transposed
    # grids whose strides cannot be flattened without a copy.
    xy = xy_grid.reshape(batch_size, n_rays_per_image, 2)

    # [batch_size, 3 * n_rays_per_image, 2]
    xy = torch.cat(
        [
            xy,
            # Will allow to find the norm on the x axis
            xy + torch.tensor([pixel_hw_ndc[1], 0], device=xy.device),
            # Will allow to find the norm on the y axis
            xy + torch.tensor([0, pixel_hw_ndc[0]], device=xy.device),
        ],
        dim=1,
    )

    # Append a unit depth coordinate: [batch_size, 3 * n_rays_per_image, 3]
    xyz = torch.cat(
        (
            xy,
            xy.new_ones(batch_size, 3 * n_rays_per_image, 1),
        ),
        dim=-1,
    )

    # unproject the points
    unprojected_xyz = cameras.unproject_points(xyz, from_ndc=True)

    # Undo the concatenation above: centers, x-neighbors, y-neighbors.
    plane_world, plane_world_dx, plane_world_dy = torch.split(
        unprojected_xyz, n_rays_per_image, dim=1
    )

    # Distance from each unprojected point to its neighbors.
    dx_norm = torch.linalg.norm(plane_world_dx - plane_world, dim=-1, keepdim=True)
    dy_norm = torch.linalg.norm(plane_world_dy - plane_world, dim=-1, keepdim=True)

    # Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5
    # Scale it by 2/12**0.5 to match the variance of the pixel's footprint
    radii = (dx_norm + dy_norm) / 12**0.5

    return radii.reshape(batch_size, *spatial_size, 1)