# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform
from .rasterize_meshes import rasterize_meshes
@dataclass(frozen=True)
class Fragments:
    """
    A dataclass representing the outputs of a rasterizer. Can be detached from the
    computational graph in order to stop the gradients from flowing through the
    rasterizer.

    Below, H and W are the image height and width (both equal to image_size when
    a single int is used) and K is faces_per_pixel.

    Members:
        pix_to_face:
            LongTensor of shape (N, H, W, K) giving the indices of the nearest
            faces at each pixel, sorted in ascending z-order. Concretely
            ``pix_to_face[n, y, x, k] = f`` means that ``face_verts[f]`` is the
            kth closest face (in the z-direction) to pixel (y, x). Pixels that
            are hit by fewer than K faces are padded with -1.
        zbuf:
            FloatTensor of shape (N, H, W, K) giving the NDC z-coordinates of
            the nearest faces at each pixel, sorted in ascending z-order.
            Concretely, if ``pix_to_face[n, y, x, k] = f`` then
            ``zbuf[n, y, x, k]`` is the z-coordinate of face ``face_verts[f]``
            at pixel (y, x). Pixels hit by fewer than K faces are padded with -1.
        bary_coords:
            FloatTensor of shape (N, H, W, K, 3) giving the barycentric
            coordinates of the nearest faces at each pixel, sorted in ascending
            z-order. Concretely, if ``pix_to_face[n, y, x, k] = f`` then
            ``[w0, w1, w2] = bary_coords[n, y, x, k]`` gives the barycentric
            coords for pixel (y, x) relative to the face defined by
            ``face_verts[f]``. Pixels hit by fewer than K faces are padded
            with -1.
        dists:
            Optional (may be None) FloatTensor of shape (N, H, W, K) giving the
            signed squared Euclidean distance (in NDC units) in the x/y plane
            between each pixel and each of its nearest faces. Concretely, if
            ``pix_to_face[n, y, x, k] = f`` then ``dists[n, y, x, k]`` is the
            signed squared distance between pixel (y, x) and the face given by
            vertices ``face_verts[f]``. Pixels hit by fewer than K faces are
            padded with -1.
    """

    pix_to_face: torch.Tensor
    zbuf: torch.Tensor
    bary_coords: torch.Tensor
    dists: Optional[torch.Tensor]

    def detach(self) -> "Fragments":
        """
        Return a new Fragments whose float tensors are detached from the
        autograd graph.

        ``pix_to_face`` holds integer face indices, which cannot require
        gradients, so it is passed through unchanged. ``dists`` stays None if
        it was None.
        """
        dists = self.dists
        return Fragments(
            pix_to_face=self.pix_to_face,
            zbuf=self.zbuf.detach(),
            bary_coords=self.bary_coords.detach(),
            dists=None if dists is None else dists.detach(),
        )
@dataclass
class RasterizationSettings:
    """
    Parameters controlling mesh rasterization. Every field has a default, so
    ``RasterizationSettings()`` is a complete configuration on its own.

    Members:
        image_size: Output size in pixels, either a single int used for both
            height and width, or a (height, width) tuple.
        blur_radius: Float distance in the range [0, 2] by which face bounding
            boxes are expanded during rasterization. A positive value produces
            blurred edges around the shape rather than a hard boundary; 0
            disables blurring.
        faces_per_pixel: (int) How many faces to track per pixel; the
            faces_per_pixel nearest faces along the z-axis are returned.
        bin_size: Bin size for coarse-to-fine rasterization. 0 selects naive
            rasterization; None picks a value heuristically from the shape of
            the input. This should not change the output, only the speed of the
            forward pass.
        max_faces_opengl: Upper bound on the number of faces in any mesh to be
            rasterized. Used only by MeshRasterizerOpenGL, to pre-allocate
            OpenGL memory.
        max_faces_per_bin: Upper bound on the number of faces within each bin;
            only relevant for coarse-to-fine rasterization (bin_size != 0). This
            should not change the output values, only the memory used by the
            forward pass. None picks a value heuristically.
        perspective_correct: Whether to apply perspective correction when
            computing barycentric coordinates for pixels. None (default) means
            correct if and only if the camera uses perspective.
        clip_barycentric_coords: Whether to move a location that lies outside
            its face (i.e. has a negative barycentric coordinate) onto the edge
            of the face. This happens after any perspective correction and
            before the depth is computed (e.g. for z clipping). None (default)
            means clip when blur_radius > 0, the situation in which such
            outside-face points are likely.
        cull_backfaces: Whether to rasterize only the mesh faces visible to the
            camera. Front-facing triangles are assumed to have their vertices
            ordered anti-clockwise, and triangles facing away from the camera
            clockwise, relative to the current view direction. NOTE: this only
            works if the mesh faces are consistently defined with
            counter-clockwise ordering when viewed from the outside.
        z_clip_value: If not None, triangles are clipped (and possibly
            subdivided into smaller triangles) so that z >= z_clip_value. This
            avoids camera projections that go to infinity as z->0. Defaults to
            None because clipping slows rasterization and should only be turned
            on when explicitly needed; see clip.py for the extra computation
            involved.
        cull_to_frustum: Whether to remove all faces that fall outside the view
            frustum. Defaults to False for performance, as it is often not
            needed.
    """

    # Output resolution and coverage.
    image_size: Union[int, Tuple[int, int]] = 256
    blur_radius: float = 0.0
    faces_per_pixel: int = 1
    # Coarse-to-fine binning / memory limits.
    bin_size: Optional[int] = None
    max_faces_opengl: int = 10_000_000
    max_faces_per_bin: Optional[int] = None
    # Barycentric handling.
    perspective_correct: Optional[bool] = None
    clip_barycentric_coords: Optional[bool] = None
    # Culling and clipping.
    cull_backfaces: bool = False
    z_clip_value: Optional[float] = None
    cull_to_frustum: bool = False
class MeshRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of heterogeneous
    Meshes.
    """

    def __init__(self, cameras=None, raster_settings=None) -> None:
        """
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
                world-to-view and view-to-ndc transformations.
            raster_settings: the parameters for rasterization, normally a
                RasterizationSettings instance. If None, a default
                ``RasterizationSettings()`` is used.

        Both can be overridden per call by passing ``cameras=`` or
        ``raster_settings=`` as keyword arguments to forward.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = RasterizationSettings()
        self.cameras = cameras
        self.raster_settings = raster_settings

    def to(self, device):
        """Move the cameras (if any) to ``device`` and return ``self``."""
        # Manually move the cameras to the device, as they are not a subclass
        # of nn.Module and so are not moved by nn.Module.to.
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        return self

    def transform(self, meshes_world, **kwargs):
        """
        Project world-space mesh vertices into the space used for rasterization.

        Args:
            meshes_world: a Meshes object representing a batch of meshes with
                vertex coordinates in world space.
            **kwargs: ``cameras`` overrides ``self.cameras``; ``eps`` is passed
                to every ``transform_points`` call; all kwargs are forwarded to
                the cameras' ``get_world_to_view_transform`` /
                ``get_ndc_camera_transform`` and to
                ``try_get_projection_transform``.

        Returns:
            A Meshes object (built with ``meshes_world.update_padded``) whose
            vertices have x/y projected to NDC and z set to the view-space z.

        Raises:
            ValueError: if no cameras were given at construction or in kwargs,
                or if the number of cameras is neither 1 nor the number of
                meshes.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of MeshRasterizer"
            )

        n_cameras = len(cameras)
        n_meshes = len(meshes_world)
        if n_cameras != 1 and n_cameras != n_meshes:
            raise ValueError(
                "Wrong number (%r) of cameras for %r meshes" % (n_cameras, n_meshes)
            )

        verts_world = meshes_world.verts_padded()
        eps = kwargs.get("eps", None)

        verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
            verts_world, eps=eps
        )
        to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
        projection_transform = try_get_projection_transform(cameras, kwargs)
        if projection_transform is not None:
            projection_transform = projection_transform.compose(to_ndc_transform)
            verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
        else:
            # No explicit projection transform was obtained (presumably a camera
            # without a projection-matrix form): let the camera project the
            # points itself, then map the result to NDC.
            verts_proj = cameras.transform_points(verts_world, eps=eps)
            verts_ndc = to_ndc_transform.transform_points(verts_proj, eps=eps)

        # Retain the view-space z coordinate rather than the projected one.
        verts_ndc[..., 2] = verts_view[..., 2]
        return meshes_world.update_padded(new_verts_padded=verts_ndc)

    def forward(self, meshes_world, **kwargs) -> "Fragments":
        """
        Rasterize a batch of meshes.

        Args:
            meshes_world: a Meshes object representing a batch of meshes with
                coordinates in world space.
            **kwargs: optional per-call overrides. ``cameras`` and
                ``raster_settings`` replace the values given at construction;
                all kwargs are also passed on to ``self.transform``.

        Returns:
            Fragments: the rasterization outputs (a Fragments dataclass).

        Raises:
            ValueError: propagated from ``self.transform`` when no cameras are
                available or the camera count does not match the mesh count.
        """
        meshes_proj = self.transform(meshes_world, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)

        # By default, turn on clip_barycentric_coords if blur_radius > 0.
        # When blur_radius > 0, a face can be matched to a pixel that is outside the
        # face, resulting in negative barycentric coordinates.
        clip_barycentric_coords = raster_settings.clip_barycentric_coords
        if clip_barycentric_coords is None:
            clip_barycentric_coords = raster_settings.blur_radius > 0.0

        # If not specified, infer perspective_correct and z_clip_value from the camera.
        cameras = kwargs.get("cameras", self.cameras)
        if raster_settings.perspective_correct is not None:
            perspective_correct = raster_settings.perspective_correct
        else:
            perspective_correct = cameras.is_perspective()
        if raster_settings.z_clip_value is not None:
            z_clip = raster_settings.z_clip_value
        else:
            znear = cameras.get_znear()
            if isinstance(znear, torch.Tensor):
                # znear may be a tensor (e.g. one value per camera); use the
                # smallest so clipping is safe for every camera.
                znear = znear.min().item()
            # Clip at half the near-plane distance, but only for perspective
            # correction and when a znear is known.
            z_clip = None if not perspective_correct or znear is None else znear / 2

        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_proj,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            clip_barycentric_coords=clip_barycentric_coords,
            perspective_correct=perspective_correct,
            cull_backfaces=raster_settings.cull_backfaces,
            z_clip_value=z_clip,
            cull_to_frustum=raster_settings.cull_to_frustum,
        )
        return Fragments(
            pix_to_face=pix_to_face,
            zbuf=zbuf,
            bary_coords=bary_coords,
            dists=dists,
        )