# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import math
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...camera_conversions import _pulsar_from_cameras_projection
from ...cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from ..compositor import AlphaCompositor, NormWeightedCompositor
from ..rasterizer import PointsRasterizer
from .renderer import Renderer as PulsarRenderer
def _ensure_float_tensor(val_in, device):
"""Make sure that the value provided is wrapped a PyTorch float tensor."""
if not isinstance(val_in, torch.Tensor):
val_out = torch.tensor(val_in, dtype=torch.float32, device=device).reshape((1,))
else:
val_out = val_in.to(torch.float32).to(device).reshape((1,))
return val_out
[docs]
class PulsarPointsRenderer(nn.Module):
"""
This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
It provides an interface consistent with PyTorch3D Pointcloud rendering.
It will extract all necessary information from the rasterizer and compositor
objects and convert them to the pulsar required format, then invoke rendering
in the pulsar renderer. All gradients are handled appropriately through the
wrapper and the wrapper should provide equivalent results to using the pulsar
renderer directly.
"""
[docs]
def __init__(
self,
rasterizer: PointsRasterizer,
compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
n_channels: int = 3,
max_num_spheres: int = int(1e6), # noqa: B008
**kwargs,
) -> None:
"""
rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
compositor (ignored): Only keeping this for interface consistency. Default: None.
n_channels (int): The number of channels of the resulting image. Default: 3.
max_num_spheres (int): The maximum number of spheres intended to render with
this renderer. Default: 1e6.
kwargs (Any): kwargs to pass on to the pulsar renderer.
See `pytorch3d.renderer.points.pulsar.renderer.Renderer` for all options.
"""
super().__init__()
self.rasterizer = rasterizer
if compositor is not None:
warnings.warn(
"Creating a `PulsarPointsRenderer` with a compositor object! "
"This object is ignored and just allowed as an argument for interface "
"compatibility."
)
# Initialize the pulsar renderers.
if not isinstance(
rasterizer.cameras,
(
FoVOrthographicCameras,
FoVPerspectiveCameras,
PerspectiveCameras,
OrthographicCameras,
),
):
raise ValueError(
"Only FoVPerspectiveCameras, PerspectiveCameras, "
"FoVOrthographicCameras and OrthographicCameras are supported "
"by the pulsar backend."
)
if isinstance(rasterizer.raster_settings.image_size, tuple):
height, width = rasterizer.raster_settings.image_size
else:
width = rasterizer.raster_settings.image_size
height = rasterizer.raster_settings.image_size
# Making sure about integer types.
width = int(width)
height = int(height)
max_num_spheres = int(max_num_spheres)
orthogonal_projection = isinstance(
rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
)
n_channels = int(n_channels)
self.renderer = PulsarRenderer(
width=width,
height=height,
max_num_balls=max_num_spheres,
orthogonal_projection=orthogonal_projection,
right_handed_system=False,
n_channels=n_channels,
**kwargs,
)
def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
"""
Verify internal configuration state with kwargs and pointclouds.
This method will raise ValueError's for any inconsistencies found. It
returns whether an orthogonal projection will be used.
"""
if "gamma" not in kwargs.keys():
raise ValueError(
"gamma is a required keyword argument for the PulsarPointsRenderer!"
)
if (
len(point_clouds) != len(self.rasterizer.cameras)
and len(self.rasterizer.cameras) != 1
):
raise ValueError(
(
"The len(point_clouds) must either be equal to len(rasterizer.cameras) or "
"only one camera must be used. len(point_clouds): %d, "
"len(rasterizer.cameras): %d."
)
% (
len(point_clouds),
len(self.rasterizer.cameras),
)
)
# Make sure the rasterizer and cameras objects have no
# changes that can't be matched.
orthogonal_projection = isinstance(
self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
)
if orthogonal_projection != self.renderer._renderer.orthogonal:
raise ValueError(
"The camera type can not be changed after renderer initialization! "
"Current camera orthogonal: %r. Original orthogonal: %r."
) % (orthogonal_projection, self.renderer._renderer.orthogonal)
image_size = self.rasterizer.raster_settings.image_size
if isinstance(image_size, tuple):
expected_height, expected_width = image_size
else:
expected_height = expected_width = image_size
if expected_width != self.renderer._renderer.width:
raise ValueError(
(
"The rasterizer width can not be changed after renderer "
"initialization! Current width: %s. Original width: %d."
)
% (
expected_width,
self.renderer._renderer.width,
)
)
if expected_height != self.renderer._renderer.height:
raise ValueError(
(
"The rasterizer height can not be changed after renderer "
"initialization! Current height: %s. Original height: %d."
)
% (
expected_height,
self.renderer._renderer.height,
)
)
return orthogonal_projection
def _extract_intrinsics( # noqa: C901
self, orthogonal_projection, kwargs, cloud_idx, device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
"""
Translate the camera intrinsics from PyTorch3D format to pulsar format.
"""
# Shorthand:
cameras = self.rasterizer.cameras
if orthogonal_projection:
focal_length = torch.zeros((1,), dtype=torch.float32)
if isinstance(cameras, FoVOrthographicCameras):
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
if max_y != -min_y:
raise ValueError(
"The orthographic camera must be centered around 0. "
f"Max is {max_y} and min is {min_y}."
)
if max_x != -min_x:
raise ValueError(
"The orthographic camera must be centered around 0. "
f"Max is {max_x} and min is {min_x}."
)
if not torch.all(
kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx] == 1.0
):
raise ValueError(
"The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
f"{kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
)
sensor_width = max_x - min_x
if not sensor_width > 0.0:
raise ValueError(
f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950
)
principal_point_x, principal_point_y = (
torch.zeros((1,), dtype=torch.float32),
torch.zeros((1,), dtype=torch.float32),
)
else:
# Currently, this means it must be an 'OrthographicCameras' object.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
cloud_idx
]
if (
focal_length_conf.numel() == 2
and focal_length_conf[0] * self.renderer._renderer.width
- focal_length_conf[1] * self.renderer._renderer.height
> 1e-5
):
raise ValueError(
"Pulsar only supports a single focal length! "
"Provided: %s." % (str(focal_length_conf))
)
if focal_length_conf.numel() == 2:
sensor_width = 2.0 / focal_length_conf[0]
else:
if focal_length_conf.numel() != 1:
raise ValueError(
"Focal length not parsable: %s." % (str(focal_length_conf))
)
sensor_width = 2.0 / focal_length_conf
if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
raise ValueError(
"pulsar needs znear and zfar values for "
"the OrthographicCameras. Please provide them as keyword "
"argument to the forward method."
)
znear = kwargs["znear"][cloud_idx]
zfar = kwargs["zfar"][cloud_idx]
principal_point_x = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
* 0.5
* self.renderer._renderer.width
)
principal_point_y = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
* 0.5
* self.renderer._renderer.height
)
else:
if not isinstance(cameras, PerspectiveCameras):
# Create a virtual focal length that is closer than znear.
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
focal_length = znear - 1e-6
# Create a sensor size that matches the expected fov assuming this f.
afov = kwargs.get("fov", cameras.fov)[cloud_idx]
if kwargs.get("degrees", cameras.degrees):
afov *= math.pi / 180.0
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
if not (
kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
- self.renderer._renderer.width / self.renderer._renderer.height
< 1e-6
):
raise ValueError(
"The aspect ratio ("
f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
"must agree with the resolution width / height ("
f"{self.renderer._renderer.width / self.renderer._renderer.height})." # noqa: B950
)
principal_point_x, principal_point_y = (
torch.zeros((1,), dtype=torch.float32),
torch.zeros((1,), dtype=torch.float32),
)
else:
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
cloud_idx
]
if (
focal_length_conf.numel() == 2
and focal_length_conf[0] * self.renderer._renderer.width
- focal_length_conf[1] * self.renderer._renderer.height
> 1e-5
):
raise ValueError(
"Pulsar only supports a single focal length! "
"Provided: %s." % (str(focal_length_conf))
)
if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
raise ValueError(
"pulsar needs znear and zfar values for "
"the PerspectiveCameras. Please provide them as keyword "
"argument to the forward method."
)
znear = kwargs["znear"][cloud_idx]
zfar = kwargs["zfar"][cloud_idx]
if focal_length_conf.numel() == 2:
focal_length_px = focal_length_conf[0]
else:
if focal_length_conf.numel() != 1:
raise ValueError(
"Focal length not parsable: %s." % (str(focal_length_conf))
)
focal_length_px = focal_length_conf
focal_length = torch.tensor(
[
znear - 1e-6,
],
dtype=torch.float32,
device=focal_length_px.device,
)
sensor_width = focal_length / focal_length_px * 2.0
principal_point_x = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
* 0.5
* self.renderer._renderer.width
)
principal_point_y = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
* 0.5
* self.renderer._renderer.height
)
focal_length = _ensure_float_tensor(focal_length, device)
sensor_width = _ensure_float_tensor(sensor_width, device)
principal_point_x = _ensure_float_tensor(principal_point_x, device)
principal_point_y = _ensure_float_tensor(principal_point_y, device)
znear = _ensure_float_tensor(znear, device)
zfar = _ensure_float_tensor(zfar, device)
return (
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
znear,
zfar,
)
def _extract_extrinsics(
self, kwargs, cloud_idx
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Extract the extrinsic information from the kwargs for a specific point cloud.
Instead of implementing a direct translation from the PyTorch3D to the Pulsar
camera model, we chain the two conversions of PyTorch3D->OpenCV and
OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
tested by the Pulsar team).
"""
# Shorthand:
cameras = self.rasterizer.cameras
R = kwargs.get("R", cameras.R)[cloud_idx]
T = kwargs.get("T", cameras.T)[cloud_idx]
tmp_cams = PerspectiveCameras(
R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
)
size_tensor = torch.tensor(
[[self.renderer._renderer.height, self.renderer._renderer.width]]
)
pulsar_cam = _pulsar_from_cameras_projection(tmp_cams, size_tensor)
cam_pos = pulsar_cam[0, :3]
cam_rot = pulsar_cam[0, 3:9]
return cam_pos, cam_rot
def _get_vert_rad(
self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
) -> torch.Tensor:
"""
Get point radiuses.
These can be depending on the camera position in case of a perspective
transform.
"""
# Normalize point radiuses.
# `self.rasterizer.raster_settings.radius` can either be a float
# or itself a tensor.
raster_rad = self.rasterizer.raster_settings.radius
if kwargs.get("radius_world", False):
return raster_rad
if (
isinstance(raster_rad, torch.Tensor)
and raster_rad.numel() > 1
and raster_rad.ndim > 1
):
# In this case it must be a batched torch tensor.
raster_rad = raster_rad[cloud_idx]
if orthogonal_projection:
vert_rad = (
torch.ones(
(vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device
)
* raster_rad
)
else:
point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists
if isinstance(self.rasterizer.cameras, PerspectiveCameras):
# NDC normalization happens through adjusted focal length.
pass
else:
vert_rad = vert_rad / 2.0 # NDC normalization.
return vert_rad
# point_clouds is not typed to avoid a cyclic dependency.
[docs]
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
"""
Get the rendering of the provided `Pointclouds`.
The number of point clouds in the `Pointclouds` object determines the
number of resulting images. The provided cameras can be either 1 or equal
to the number of pointclouds (in the first case, the same camera will be
used for all clouds, in the latter case each point cloud will be rendered
with the corresponding camera).
The following kwargs are support from PyTorch3D (depending on the selected
camera model potentially overriding camera parameters):
radius_world (bool): use the provided radiuses from the raster_settings
plain as radiuses in world space. Default: False.
znear (Iterable[float]): near geometry cutoff. Is required for
OrthographicCameras and PerspectiveCameras.
zfar (Iterable[float]): far geometry cutoff. Is required for
OrthographicCameras and PerspectiveCameras.
R (torch.Tensor): [Bx3x3] camera rotation matrices.
T (torch.Tensor): [Bx3] camera translation vectors.
principal_point (torch.Tensor): [Bx2] camera intrinsic principal
point offset vectors.
focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths.
aspect_ratio (Iterable[float]): camera aspect ratios.
fov (Iterable[float]): camera FOVs.
degrees (bool): whether FOVs are specified in degrees or
radians.
min_x (Iterable[float]): minimum x for the FoVOrthographicCameras.
max_x (Iterable[float]): maximum x for the FoVOrthographicCameras.
min_y (Iterable[float]): minimum y for the FoVOrthographicCameras.
max_y (Iterable[float]): maximum y for the FoVOrthographicCameras.
The following kwargs are supported from pulsar:
gamma (float): The gamma value to use. This defines the transparency for
differentiability (see pulsar paper for details). Must be in [1., 1e-5]
with 1.0 being mostly transparent. This keyword argument is *required*!
bg_col (torch.Tensor): The background color. Must be a tensor on the same
device as the point clouds, with as many channels as features (no batch
dimension - it is the same for all images in the batch).
Default: 0.0 for all channels.
percent_allowed_difference (float): a value in [0., 1.[ with the maximum
allowed difference in channel space. This is used to speed up the
computation. Default: 0.01.
max_n_hits (int): a hard limit on the number of sphere hits per ray.
Default: max int.
mode (int): render mode in {0, 1}. 0: render image; 1: render hit map.
"""
orthogonal_projection: bool = self._conf_check(point_clouds, kwargs)
# Get access to inputs. We're using the list accessor and process
# them sequentially.
position_list = point_clouds.points_list()
features_list = point_clouds.features_list()
# Result list.
images = []
for cloud_idx, (vert_pos, vert_col) in enumerate(
zip(position_list, features_list)
):
# Get extrinsics.
cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
# Get intrinsics.
(
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
znear,
zfar,
) = self._extract_intrinsics(
orthogonal_projection, kwargs, cloud_idx, cam_pos.device
)
# Put everything together.
cam_params = torch.cat(
(
cam_pos,
cam_rot.to(cam_pos.device),
torch.cat(
[
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
],
),
)
)
# Get point radiuses (can depend on camera position).
vert_rad = self._get_vert_rad(
vert_pos,
cam_pos,
orthogonal_projection,
focal_length,
kwargs,
cloud_idx,
)
# Clean kwargs for passing on.
gamma = kwargs["gamma"][cloud_idx]
if "first_R_then_T" in kwargs.keys():
raise ValueError("`first_R_then_T` is not supported in this interface.")
otherargs = {
argn: argv
for argn, argv in kwargs.items()
if argn
not in [
"radius_world",
"gamma",
"znear",
"zfar",
"R",
"T",
"principal_point",
"focal_length",
"aspect_ratio",
"fov",
"degrees",
"min_x",
"max_x",
"min_y",
"max_y",
]
}
# background color
if "bg_col" not in otherargs:
bg_col = torch.zeros(
vert_col.shape[1], device=cam_params.device, dtype=torch.float32
)
otherargs["bg_col"] = bg_col
# Go!
images.append(
self.renderer(
vert_pos=vert_pos,
vert_col=vert_col,
vert_rad=vert_rad,
cam_params=cam_params,
gamma=gamma,
max_depth=zfar,
min_depth=znear,
**otherargs,
).flip(dims=[0])
)
return torch.stack(images, dim=0)