#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform
from pytorch3d.structures import Pointclouds
from .rasterize_points import rasterize_points
class PointFragments(NamedTuple):
    """
    Class to store the outputs of point rasterization.

    In the shapes below, (H, W) is the rasterized image size: (image_size,
    image_size) when image_size is an int, or image_size itself when it is a
    (height, width) tuple. `points` denotes the packed points of shape (P, 3).

    Members:
        idx: int32 Tensor of shape (N, H, W, points_per_pixel) giving the
            indices of the nearest points at each pixel, in ascending z-order.
            Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
            closest point (along the z-direction) to pixel (y, x).
            Pixels that are hit by fewer than points_per_pixel points are
            padded with -1.
        zbuf: Tensor of shape (N, H, W, points_per_pixel) giving the
            z-coordinates of the nearest points at each pixel, sorted in
            z-order. Concretely, if `idx[n, y, x, k] = p` then
            `zbuf[n, y, x, k] = points[p, 2]`. Pixels hit by fewer than
            points_per_pixel points are padded with -1.
        dists: Tensor of shape (N, H, W, points_per_pixel) giving the squared
            Euclidean distance (in NDC units) in the x/y plane for each point
            closest to the pixel. Concretely if `idx[n, y, x, k] = p` then
            `dists[n, y, x, k]` is the squared distance between the pixel
            (y, x) and the point `(points[p, 0], points[p, 1])`. Pixels hit by
            fewer than points_per_pixel points are padded with -1.
    """

    idx: torch.Tensor
    zbuf: torch.Tensor
    dists: torch.Tensor
@dataclass
class PointsRasterizationSettings:
    """
    Class to store the point rasterization params with defaults.

    Members:
        image_size: Either a single int used as both height and width, or a
            (height, width) tuple, in pixels.
        radius: The radius (in NDC units) of each disk to be rasterized.
            This can either be a float, in which case the same radius is used
            for every point, or a torch.Tensor of shape (N, P) giving a radius
            per point in the batch.
        points_per_pixel: (int) Number of points to keep track of per pixel.
            We return the nearest points_per_pixel points along the z-axis.
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts
            to set it heuristically based on the shape of the input. This should
            not affect the output, but can affect the speed of the forward pass.
        max_points_per_bin: Only applicable when using coarse-to-fine
            rasterization (bin_size != 0); this is the maximum number of points
            allowed within each bin. This should not affect the output values,
            but can affect the memory usage in the forward pass.
            Setting max_points_per_bin=None attempts to set it with a heuristic.
    """

    image_size: Union[int, Tuple[int, int]] = 256
    radius: Union[float, torch.Tensor] = 0.01
    points_per_pixel: int = 8
    bin_size: Optional[int] = None
    max_points_per_bin: Optional[int] = None
class PointsRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of pointclouds.
    """

    def __init__(self, cameras=None, raster_settings=None) -> None:
        """
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
                world-to-view and view-to-ndc transformations.
            raster_settings: the parameters for rasterization. This should be a
                PointsRasterizationSettings instance (a dataclass), or any
                object exposing the same attributes. If None, default
                PointsRasterizationSettings are used.

        Both values can be overridden per call by passing `cameras=` and/or
        `raster_settings=` as keyword arguments to the forward function.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = PointsRasterizationSettings()

        self.cameras = cameras
        self.raster_settings = raster_settings

    def transform(self, point_clouds, **kwargs) -> Pointclouds:
        """
        Project the points of `point_clouds` from world space to NDC space.

        Args:
            point_clouds: a set of point clouds with coordinates in world space.
            **kwargs: optional overrides. `cameras` replaces `self.cameras`,
                `eps` is forwarded to every `transform_points` call, and all
                kwargs are forwarded to the camera transform getters.

        Returns:
            The point clouds produced by `point_clouds.update_padded`, whose
            x/y coordinates are the projected (NDC) positions and whose z
            coordinate is the view-space z of each point.

        Raises:
            ValueError: if no cameras were given at initialization or in kwargs.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of PointsRasterizer"
            )

        pts_world = point_clouds.points_padded()
        eps = kwargs.get("eps", None)
        pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
            pts_world, eps=eps
        )
        to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
        projection_transform = try_get_projection_transform(cameras, kwargs)
        if projection_transform is not None:
            # Compose projection and NDC conversion so view points are
            # transformed in a single step.
            projection_transform = projection_transform.compose(to_ndc_transform)
            pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
        else:
            # Some camera classes have no projection-matrix form; use their own
            # transform_points and then map the result to NDC.
            pts_proj = cameras.transform_points(pts_world, eps=eps)
            pts_ndc = to_ndc_transform.transform_points(pts_proj, eps=eps)

        # Retain the view-space z coordinate instead of the projected z.
        pts_ndc[..., 2] = pts_view[..., 2]
        return point_clouds.update_padded(pts_ndc)

    def to(self, device):
        """
        Move the cameras to `device` and return self.

        The cameras are moved manually because they are not an nn.Module
        submodule, so nn.Module's own device handling would not reach them.
        """
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        return self

    def forward(self, point_clouds, **kwargs) -> PointFragments:
        """
        Rasterize a batch of point clouds.

        Args:
            point_clouds: a set of point clouds with coordinates in world space.
            **kwargs: optional per-call overrides. `raster_settings` replaces
                `self.raster_settings`; all kwargs (e.g. `cameras`, `eps`) are
                also passed on to `transform`.

        Returns:
            PointFragments: Rasterization outputs as a named tuple.
        """
        points_proj = self.transform(point_clouds, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        idx, zbuf, dists2 = rasterize_points(
            points_proj,
            image_size=raster_settings.image_size,
            radius=raster_settings.radius,
            points_per_pixel=raster_settings.points_per_pixel,
            bin_size=raster_settings.bin_size,
            max_points_per_bin=raster_settings.max_points_per_bin,
        )
        return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)