# Source code for pytorch3d.implicitron.models.renderer.sdf_renderer

# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr
# Copyright (c) 2020 Lior Yariv

# pyre-unsafe
import functools
from typing import List, Optional, Tuple

import torch
from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import (
    get_default_args_field,
    registry,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.utils import evaluating

from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_tracing import RayTracing
from .rgb_net import RayNormalColoringNetwork

@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
    """
    Renderer for a surface represented by a signed distance function (SDF).

    For every ray, `ray_tracer` locates a point along the ray; rays reported
    as hitting the surface are shaded by a `RayNormalColoringNetwork` from the
    implicit function's features and the SDF gradient at the point, all other
    rays keep the background colour and get a sigmoid-based soft mask value.

    The implicit function is expected to return, per point, the SDF value in
    channel 0 and a feature vector in the remaining channels (this is how its
    output is indexed in `forward`).
    """

    # Number of channels of the rendered features; also passed to the
    # colouring network as `feature_vector_size`.
    render_features_dimensions: int = 3
    # Forwarded to `RayTracing`; in training mode it is also the half-extent
    # of the box from which eikonal points are sampled uniformly.
    object_bounding_sphere: float = 1.0
    # pyre-fixme[13]: Attribute `ray_tracer` is never initialized.
    ray_tracer: RayTracing
    ray_normal_coloring_network_args: DictConfig = get_default_args_field(
        RayNormalColoringNetwork
    )
    # Background colour: either a single value or one value per rendered channel.
    bg_color: Tuple[float, ...] = (0.0,)
    # Scale applied to the SDF value inside the sigmoid of the soft mask.
    soft_mask_alpha: float = 50.0

    def __post_init__(
        self,
    ):
        """
        Validate `bg_color`, build the configured members and the colouring
        network, and register the background colour as a buffer.

        Raises:
            ValueError: if `bg_color` has neither 1 nor
                `render_features_dimensions` entries.
        """
        render_features_dimensions = self.render_features_dimensions
        if len(self.bg_color) not in [1, render_features_dimensions]:
            raise ValueError(
                f"Background color should have {render_features_dimensions} entries."
            )

        run_auto_creation(self)

        self.ray_normal_coloring_network_args["feature_vector_size"] = (
            render_features_dimensions
        )
        self._rgb_network = RayNormalColoringNetwork(
            **self.ray_normal_coloring_network_args
        )

        # Non-persistent: the colour comes from the config, not the state dict.
        self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)

    @classmethod
    def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
        """
        Remove `object_bounding_sphere` from the ray tracer's config arguments;
        `create_ray_tracer` supplies it from this renderer's own field instead.
        """
        del args["object_bounding_sphere"]

    def create_ray_tracer(self) -> None:
        """
        Instantiate `self.ray_tracer` from `self.ray_tracer_args`, sharing this
        renderer's `object_bounding_sphere`.
        """
        self.ray_tracer = RayTracing(
            # pyre-fixme[32]: Keyword argument must be a mapping with string keys.
            **self.ray_tracer_args,
            object_bounding_sphere=self.object_bounding_sphere,
        )

    def requires_object_mask(self) -> bool:
        """Return True: `forward` raises if `object_mask` is not given."""
        return True

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        object_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: A `ImplicitronRayBundle` object containing the
                parametrizations of the sampled rendering rays.
            implicit_functions: single element list of ImplicitFunctionWrappers
                which defines the implicit function to be used.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used
                for rendering.
            object_mask: BoolTensor, denoting the silhouette of the object.
                Although it defaults to None, it is required by
                SignedDistanceFunctionRenderer.
            kwargs: ignored.

        Returns:
            instance of RendererOutput with rendered `features`, `normals`,
            `depths`, soft `masks`, surface `points`, and
            `aux["grad_theta"]` (SDF gradients at the eikonal points in
            training mode, None otherwise).

        Raises:
            ValueError: if `implicit_functions` does not have exactly one
                element, or if `object_mask` is None.
        """
        if len(implicit_functions) != 1:
            raise ValueError(
                "SignedDistanceFunctionRenderer supports only single pass."
            )

        if object_mask is None:
            raise ValueError("Expected object_mask to be provided in the kwargs")
        object_mask = object_mask.bool()

        implicit_function = implicit_functions[0]
        implicit_function_gradient = functools.partial(_gradient, implicit_function)

        # object_mask: silhouette of the object
        batch_size, *spatial_size, _ = ray_bundle.lengths.shape
        num_pixels = prod(spatial_size)

        cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
        ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
        object_mask = object_mask.reshape(batch_size, -1)

        # The ray tracer only locates the intersections; gradients are
        # re-attached below by evaluating the implicit function again.
        with torch.no_grad(), evaluating(implicit_function):
            points, network_object_mask, dists = self.ray_tracer(
                sdf=lambda x: implicit_function(rays_points_world=x)[
                    :, 0
                ],  # TODO: get rid of this wrapper
                cam_loc=cam_loc,
                object_mask=object_mask,
                ray_directions=ray_dirs,
            )

        # TODO: below, cam_loc might as well be different
        depth = dists.reshape(batch_size, num_pixels, 1)
        points = (cam_loc + depth * ray_dirs).reshape(-1, 3)

        sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
        # NOTE most of the intermediate variables are flattened for
        # no apparent reason (here and in the ray tracer)
        ray_dirs = ray_dirs.reshape(-1, 3)
        object_mask = object_mask.reshape(-1)

        # TODO: move it to loss computation
        if evaluation_mode == EvaluationMode.TRAINING:
            # Only rays that hit the surface AND lie inside the silhouette
            # are treated as surface samples during training.
            surface_mask = network_object_mask & object_mask
            surface_points = points[surface_mask]
            surface_dists = dists[surface_mask].unsqueeze(-1)
            surface_ray_dirs = ray_dirs[surface_mask]
            surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
            surface_output = sdf_output[surface_mask]
            N = surface_points.shape[0]

            # Sample points for the eikonal loss
            eik_bounding_box: float = self.object_bounding_sphere
            n_eik_points = batch_size * num_pixels // 2
            eikonal_points = torch.empty(
                n_eik_points,
                3,
                # pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
                #  device]` but got `Union[device, Tensor, Module]`.
                device=self._bg_color.device,
            ).uniform_(-eik_bounding_box, eik_bounding_box)
            # Uniform samples are complemented by the (detached) traced points.
            eikonal_pixel_points = points.clone()
            eikonal_pixel_points = eikonal_pixel_points.detach()
            eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

            points_all = torch.cat([surface_points, eikonal_points], dim=0)

            output = implicit_function(rays_points_world=surface_points)
            surface_sdf_values = output[
                :N, 0:1
            ].detach()  # how is it different from sdf_output?

            # One gradient pass over surface + eikonal points; the first N
            # rows belong to the surface points, the rest to the eikonal ones.
            g = implicit_function_gradient(points_all)
            surface_points_grad = g[:N, 0, :].clone().detach()
            grad_theta = g[N:, 0, :]

            differentiable_surface_points = _sample_network(
                surface_output,
                surface_sdf_values,
                surface_points_grad,
                surface_dists,
                surface_cam_loc,
                surface_ray_dirs,
            )

        else:
            surface_mask = network_object_mask
            differentiable_surface_points = points[surface_mask]
            grad_theta = None

        empty_render = differentiable_surface_points.shape[0] == 0
        features = implicit_function(rays_points_world=differentiable_surface_points)[
            None, :, 1:
        ]
        # Outputs start as "background everywhere". They are only written
        # in-place when there is something to render, so they may be created
        # with requires_grad=True in the empty case.
        normals_full = features.new_zeros(
            batch_size, *spatial_size, 3, requires_grad=empty_render
        )
        render_full = (
            features.new_ones(
                batch_size,
                *spatial_size,
                self.render_features_dimensions,
                requires_grad=empty_render,
            )
            * self._bg_color
        )
        mask_full = features.new_ones(
            batch_size, *spatial_size, 1, requires_grad=empty_render
        )
        if not empty_render:
            normals = implicit_function_gradient(differentiable_surface_points)[
                None, :, 0, :
            ]
            normals_full.view(-1, 3)[surface_mask] = normals
            render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
                # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                self._rgb_network(
                    features,
                    differentiable_surface_points[None],
                    normals,
                    ray_bundle,
                    surface_mask[None, :, None],
                    pooling_fn=None,  # TODO
                )
            )
            # Rays that miss the surface get a soft, differentiable mask value.
            mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
                # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
                -self.soft_mask_alpha * sdf_output[~surface_mask]
            )

        # scatter points with surface_mask
        # Allocated with an explicit 3-channel last dimension so that the
        # `.view(-1, 3)` scatter is valid for any `render_features_dimensions`.
        points_full = features.new_zeros(batch_size, *spatial_size, 3)
        points_full.view(-1, 3)[surface_mask] = differentiable_surface_points

        # TODO: it is sparse here but otherwise dense
        return RendererOutput(
            features=render_full,
            normals=normals_full,
            depths=depth.reshape(batch_size, *spatial_size, 1),
            masks=mask_full,  # this is a differentiable approximation, see (7) in the paper
            points=points_full,
            aux={"grad_theta": grad_theta},  # TODO: will be moved to eikonal loss
            # TODO: do we need sdf_output, grad_theta? Only for loss probably
        )
def _sample_network(
    surface_output,
    surface_sdf_values,
    surface_points_grad,
    surface_dists,
    surface_cam_loc,
    surface_ray_dirs,
    eps: float = 1e-4,
):
    """
    Re-express traced surface points so that they are differentiable with
    respect to the network output, camera locations and ray directions.

    The ray depth is corrected by
    `(surface_output - surface_sdf_values) / (grad . ray_dir)` and the point
    is rebuilt as `cam_loc + depth * ray_dir`.

    Args:
        surface_output: SDF values at the surface points; the differentiable
            term of the correction.
        surface_sdf_values: SDF values at the same points that the caller
            treats as constants.
        surface_points_grad: SDF gradients at the surface points; viewed as
            (-1, 1, 3).
        surface_dists: ray depths of the surface points.
        surface_cam_loc: ray origins of the surface points.
        surface_ray_dirs: ray directions of the surface points; viewed as
            (-1, 3, 1) for the dot product.
        eps: lower bound on the magnitude of the gradient/ray dot product,
            which keeps the division finite for grazing rays.

    Returns:
        The re-parametrized surface points,
        `surface_cam_loc + corrected_depth * surface_ray_dirs`.
    """
    # t -> t(theta)
    # The dot product uses detached directions; the directions only receive
    # gradient through the final point reconstruction below.
    surface_ray_dirs_0 = surface_ray_dirs.detach()
    surface_points_dot = torch.bmm(
        surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1)
    ).squeeze(-1)
    # Sign in {-1, +1} (zero counts as +1), so the magnitude can be clipped
    # away from zero without losing the orientation.
    dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1
    surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / (
        surface_points_dot.abs().clip(eps) * dot_sign
    )

    # t(theta) -> x(theta,c,v)
    surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs

    return surface_points_theta_c_v


@torch.enable_grad()
def _gradient(module, rays_points_world):
    """
    Compute the gradient of the first output channel of `module` with respect
    to the input points.

    Gradient tracking is force-enabled for the duration of the call, and the
    returned gradients stay attached to the graph (`create_graph=True`) so
    they can be differentiated again.

    Args:
        module: object whose `forward(rays_points_world=...)` returns a
            tensor indexable as `[:, :1]`.
        rays_points_world: input points. NOTE: `requires_grad` is switched on
            in-place on this tensor, which is visible to the caller.

    Returns:
        The gradients with an extra dimension inserted at index 1, i.e.
        `rays_points_world.shape[:1] + (1,) + rays_points_world.shape[1:]`.
    """
    rays_points_world.requires_grad_(True)
    y = module.forward(rays_points_world=rays_points_world)[:, :1]
    # `ones_like` already matches y's device and does not require grad.
    d_output = torch.ones_like(y)
    # `only_inputs` is deprecated and ignored by torch.autograd.grad, so it
    # is no longer passed.
    gradients = torch.autograd.grad(
        outputs=y,
        inputs=rays_points_world,
        grad_outputs=d_output,
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradients.unsqueeze(1)