# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from .rasterize_meshes import rasterize_meshes
# Class to store the outputs of mesh rasterization
class Fragments(NamedTuple):
    """
    Container for the outputs of mesh rasterization.

    The fields are stored in the same order in which `rasterize_meshes`
    returns them, so an instance can be unpacked positionally as
    `pix_to_face, zbuf, bary_coords, dists = fragments`.
    """

    # NOTE(review): the per-field descriptions below are inferred from the
    # field names only; confirm shapes and semantics against rasterize_meshes.
    # presumably the index of the face assigned to each pixel
    pix_to_face: torch.Tensor
    # presumably the depth of each pixel/face match
    zbuf: torch.Tensor
    # presumably the barycentric coordinates of each pixel within its face
    bary_coords: torch.Tensor
    # presumably the distance from each pixel to its matched face
    dists: torch.Tensor
@dataclass
class RasterizationSettings:
    """
    Class to store the mesh rasterization params with defaults.

    Members:
        image_size: Either common height and width or (height, width), in pixels.
        blur_radius: Float distance in the range [0, 2] used to expand the face
            bounding boxes for rasterization. Setting blur radius
            results in blurred edges around the shape instead of a
            hard boundary. Set to 0 for no blur.
        faces_per_pixel: (int) Number of faces to keep track of per pixel.
            We return the nearest faces_per_pixel faces along the z-axis.
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts
            to set it heuristically based on the shape of the input. This should
            not affect the output, but can affect the speed of the forward pass.
        max_faces_per_bin: Only applicable when using coarse-to-fine
            rasterization (bin_size != 0); this is the maximum number of faces
            allowed within each bin. This should not affect the output values,
            but can affect the memory usage in the forward pass.
            Setting max_faces_per_bin=None attempts to set with a heuristic.
        perspective_correct: Whether to apply perspective correction when
            computing barycentric coordinates for pixels.
            None (default) means make correction if the camera uses perspective.
        clip_barycentric_coords: Whether, after any perspective correction
            is applied but before the depth is calculated (e.g. for
            z clipping), to "correct" a location outside the face (i.e. with
            a negative barycentric coordinate) to a position on the edge of the
            face. None (default) means clip if blur_radius > 0, which is a
            condition under which such outside-face-points are likely.
        cull_backfaces: Whether to only rasterize mesh faces which are
            visible to the camera. This assumes that vertices of
            front-facing triangles are ordered in an anti-clockwise
            fashion, and triangles that face away from the camera are
            in a clockwise order relative to the current view
            direction. NOTE: This will only work if the mesh faces are
            consistently defined with counter-clockwise ordering when
            viewed from the outside.
        z_clip_value: if not None, then triangles will be clipped (and possibly
            subdivided into smaller triangles) such that z >= z_clip_value.
            This avoids camera projections that go to infinity as z->0.
            Default is None as clipping affects rasterization speed and
            should only be turned on if explicitly needed.
            See clip.py for all the extra computation that is required.
        cull_to_frustum: Whether to cull triangles outside the view frustum.
            Culling involves removing all faces which fall outside view frustum.
            Default is False for performance as often not needed.
    """

    image_size: Union[int, Tuple[int, int]] = 256
    blur_radius: float = 0.0
    faces_per_pixel: int = 1
    # None means "choose heuristically"; 0 selects naive rasterization.
    bin_size: Optional[int] = None
    # None means "choose heuristically".
    max_faces_per_bin: Optional[int] = None
    # None means "decide from the camera type at rasterization time".
    perspective_correct: Optional[bool] = None
    # None means "clip only when blur_radius > 0".
    clip_barycentric_coords: Optional[bool] = None
    cull_backfaces: bool = False
    # None leaves the choice of z clipping to the rasterizer.
    z_clip_value: Optional[float] = None
    cull_to_frustum: bool = False
class MeshRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of heterogeneous
    Meshes.
    """

    def __init__(self, cameras=None, raster_settings=None) -> None:
        """
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
                world-to-view and view-to-ndc transformations.
            raster_settings: the parameters for rasterization. This should be a
                `RasterizationSettings` instance; if None, a
                `RasterizationSettings()` with default values is used.

        All these initial settings can be overridden by passing keyword
        arguments to the forward function.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = RasterizationSettings()

        self.cameras = cameras
        self.raster_settings = raster_settings

    def to(self, device):
        """
        Move the stored cameras to `device` and return this rasterizer.

        Args:
            device: the target device, passed through to `cameras.to`.

        Returns:
            self, so that calls can be chained.
        """
        # Manually move to device cameras as it is not a subclass of nn.Module.
        # `cameras` is optional at construction time, so there may be nothing
        # to move; in that case this is a no-op instead of an AttributeError.
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        return self

    def forward(self, meshes_world, **kwargs) -> "Fragments":
        """
        Args:
            meshes_world: a Meshes object representing a batch of meshes with
                coordinates in world space.
            **kwargs: optional per-call overrides. `raster_settings` and
                `cameras`, when given, are used instead of the values stored
                at initialization. All keyword arguments are also forwarded
                to `self.transform`.

        Returns:
            Fragments: Rasterization outputs as a named tuple.
        """
        # NOTE(review): `transform` is not defined in the visible part of this
        # class; it is expected to be provided elsewhere - confirm.
        meshes_proj = self.transform(meshes_world, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)

        # By default, turn on clip_barycentric_coords if blur_radius > 0.
        # When blur_radius > 0, a face can be matched to a pixel that is outside the
        # face, resulting in negative barycentric coordinates.
        clip_barycentric_coords = raster_settings.clip_barycentric_coords
        if clip_barycentric_coords is None:
            clip_barycentric_coords = raster_settings.blur_radius > 0.0

        # If not specified, infer perspective_correct and z_clip_value from the camera
        cameras = kwargs.get("cameras", self.cameras)
        if raster_settings.perspective_correct is not None:
            perspective_correct = raster_settings.perspective_correct
        else:
            perspective_correct = cameras.is_perspective()

        if raster_settings.z_clip_value is not None:
            z_clip = raster_settings.z_clip_value
        else:
            znear = cameras.get_znear()
            if isinstance(znear, torch.Tensor):
                # Reduce a tensor of near-plane values to one Python float by
                # taking the smallest one.
                znear = znear.min().item()
            # Only clip for perspective-correct rendering with a known near
            # plane; the clip value is then half of znear.
            z_clip = None if not perspective_correct or znear is None else znear / 2

        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_proj,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            clip_barycentric_coords=clip_barycentric_coords,
            perspective_correct=perspective_correct,
            cull_backfaces=raster_settings.cull_backfaces,
            z_clip_value=z_clip,
            cull_to_frustum=raster_settings.cull_to_frustum,
        )
        return Fragments(
            pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
        )