# Source code for pytorch3d.renderer.points.renderer

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import torch.nn as nn

# A renderer class should be initialized with a
# function for rasterization and a function for compositing.
# The rasterizer should:
#     - transform inputs from world -> screen space
#     - rasterize inputs
#     - return fragments
# The compositor can take fragments as input along with any other properties of
# the scene and generate images.

# E.g. rasterize inputs and then shade
# fragments = self.rasterize(point_clouds)
# images = self.compositor(fragments, point_clouds)
# return images

class PointsRenderer(nn.Module):
    """
    A class for rendering a batch of points. The class should
    be initialized with a rasterizer and compositor class which each have a
    forward function.

    The points are rendered with varying alpha (weights) values depending on
    the distance of the pixel center to the true point in the xy plane. The
    purpose of this is to soften the hard decision boundary, for
    differentiability. See Section 3.2 of "SynSin: End-to-end View Synthesis
    from a Single Image" (https://arxiv.org/pdf/1912.08804.pdf) for more
    details.
    """

    def __init__(self, rasterizer, compositor) -> None:
        """
        Store the two stages of the rendering pipeline.

        Args:
            rasterizer: callable invoked as ``rasterizer(point_clouds, **kwargs)``.
                ``forward`` reads ``idx`` and ``dists`` from its result and
                reads ``raster_settings.radius`` from the rasterizer itself.
            compositor: callable invoked with the permuted fragment indices,
                the per-pixel weights, the transposed packed point features
                and the same ``**kwargs``.
        """
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def to(self, device):
        """
        Move the rasterizer and the compositor to ``device``.

        Args:
            device: target device, forwarded unchanged to the ``to`` method of
                the rasterizer and of the compositor.

        Returns:
            This renderer, so that calls can be chained.
        """
        # Manually move to device rasterizer as the cameras
        # within the class are not of type nn.Module
        self.rasterizer = self.rasterizer.to(device)
        self.compositor = self.compositor.to(device)
        return self

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        """
        Rasterize ``point_clouds`` and composite the fragments into images.

        Args:
            point_clouds: object passed to the rasterizer; it must also expose
                ``features_packed()`` returning a 2-D tensor
                (presumably (num_points, channels) -- confirm with caller).
            **kwargs: forwarded unchanged to both the rasterizer and the
                compositor.

        Returns:
            The compositor output with its axis 1 moved to the last position
            (presumably (N, H, W, C) -- confirm against the compositor).
        """
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the weights.
        r = self.rasterizer.raster_settings.radius

        # Move the last axis of dists to position 1, matching the layout of
        # the permuted idx tensor handed to the compositor below.
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        # Weight is 1 at zero distance and falls off linearly in dists / r^2.
        weights = 1 - dists2 / (r * r)
        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )

        # permute so image comes at the end
        images = images.permute(0, 2, 3, 1)

        return images