#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch
import torch.nn as nn
# A renderer class should be initialized with a
# function for rasterization and a function for compositing.
# The rasterizer should:
# - transform inputs from world -> screen space
# - rasterize inputs
# - return fragments
# The compositor can take fragments as input along with any other properties of
# the scene and generate images.
# E.g. rasterize inputs and then shade
# fragments = self.rasterize(point_clouds)
# images = self.compositor(fragments, point_clouds)
# return images
class PointsRenderer(nn.Module):
    """
    A class for rendering a batch of points. The class should
    be initialized with a rasterizer and a compositor, each of which is
    callable (i.e. has a forward function).

    The points are rendered with varying alpha (weights) values depending on
    the distance of the pixel center to the true point in the xy plane. The purpose
    of this is to soften the hard decision boundary, for differentiability.
    See Section 3.2 of "SynSin: End-to-end View Synthesis from a Single Image"
    (https://arxiv.org/pdf/1912.08804.pdf) for more details.
    """

    def __init__(self, rasterizer, compositor) -> None:
        """
        Args:
            rasterizer: callable invoked as ``rasterizer(point_clouds, **kwargs)``.
                Its result must expose ``idx`` and ``dists`` attributes, and
                the rasterizer itself must expose ``raster_settings.radius``.
            compositor: callable invoked as
                ``compositor(idx, weights, features)`` that returns the images.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise the attribute assignment below raises AttributeError.
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def to(self, device):
        """Move the rasterizer and the compositor to ``device``; return self."""
        # Manually move to device rasterizer as the cameras
        # within the class are not of type nn.Module
        self.rasterizer = self.rasterizer.to(device)
        self.compositor = self.compositor.to(device)
        return self

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        """
        Rasterize ``point_clouds`` and composite the fragments into images.

        Args:
            point_clouds: object accepted by the rasterizer which also
                provides ``features_packed()``.
            **kwargs: forwarded unchanged to the rasterizer.

        Returns:
            The compositor output with its second dimension moved to the end.
        """
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the distances.
        r = self.rasterizer.raster_settings.radius

        # Move the last fragment dimension to position 1.
        # NOTE(review): assumes fragments are laid out (N, H, W, K) - confirm
        # against the rasterizer in use.
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        # Weight is 1 at zero distance and falls off with dists2 / r^2.
        weights = 1 - dists2 / (r * r)
        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
        )

        # permute so image comes at the end
        images = images.permute(0, 2, 3, 1)

        return images