# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import logging
from typing import List, Optional, Tuple

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry

from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput

logger = logging.getLogger(__name__)

@registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module):
    """
    Implements the learnable LSTM raymarching function from SRN [1].

    This requires there to be one implicit function, and it is expected to be
    like SRNImplicitFunction or SRNHyperNetImplicitFunction.

    Settings:
        num_raymarch_steps: The number of LSTM raymarching steps.
        init_depth: Initializes the bias of the output (step-length) layer so
            that, after `num_raymarch_steps` steps, the marched point lies
            roughly `init_depth` units along the ray, i.e. the farthest point
            from the camera reaches a far plane `init_depth` units away.
        init_depth_noise_std: The standard deviation of the random normal noise
            added to the initial depth of each marched ray.
        hidden_size: The dimensionality of the LSTM's hidden state.
        n_feature_channels: The number of feature channels returned by the
            implicit_function evaluated at each raymarching step.
        bg_color: If supplied, used as the background color. Otherwise the pixel
            generator is used everywhere. This has to have length either 1
            (for a constant value for all output channels) or equal to the number
            of output channels (which is `out_features` on the pixel generator,
            typically 3.)
        verbose: If `True`, logs raymarching debug info.

    References:
        [1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G.
            "Scene representation networks: Continuous 3d-structure-aware
            neural scene representations." NeurIPS 2019.
    """

    num_raymarch_steps: int = 10
    init_depth: float = 17.0
    init_depth_noise_std: float = 5e-4
    hidden_size: int = 16
    n_feature_channels: int = 256
    bg_color: Optional[List[float]] = None
    verbose: bool = False

    def __post_init__(self):
        # Recurrent marcher: consumes the per-ray features produced by the
        # implicit function and keeps a hidden state across marching steps.
        self._lstm = torch.nn.LSTMCell(
            input_size=self.n_feature_channels,
            hidden_size=self.hidden_size,
        )
        self._lstm.apply(_init_recurrent_weights)
        _lstm_forget_gate_init(self._lstm)

        # Maps the LSTM hidden state to a scalar step length per ray.
        self._out_layer = torch.nn.Linear(self.hidden_size, 1)

        # Start with near-constant steps of init_depth / num_raymarch_steps, so
        # that all the steps together march about `init_depth` along each ray.
        # The tiny weight std keeps the initial prediction dominated by the bias.
        one_step = self.init_depth / self.num_raymarch_steps
        self._out_layer.bias.data.fill_(one_step)
        self._out_layer.weight.data.normal_(mean=0.0, std=1e-3)

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        """
        March every ray with the LSTM, then render features at the final point.

        Args:
            ray_bundle: A `ImplicitronRayBundle` object containing the
                parametrizations of the sampled rendering rays. The last
                dimension of its `lengths` must be 1: a single point per ray,
                which is the initial raymarching point.
            implicit_functions: A single-element list of ImplicitFunctionWrappers
                which defines the implicit function to be used.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION. Not read by this renderer; accepted
                for interface compatibility with other renderers.
            **kwargs: Ignored.

        Returns:
            instance of RendererOutput with the rendered features, the depth of
            the final marched point (ray length scaled by the direction norm)
            and the sigmoid of the opacity logits as the mask.

        Raises:
            ValueError: if `implicit_functions` does not hold exactly one
                function, or if the last dimension of `ray_bundle.lengths`
                is not 1.
        """
        if len(implicit_functions) != 1:
            raise ValueError("LSTM renderer expects a single implicit function.")

        implicit_function = implicit_functions[0]

        if ray_bundle.lengths.shape[-1] != 1:
            raise ValueError(
                "LSTM renderer requires a ray-bundle with a single point per ray"
                + " which is the initial raymarching point."
            )

        # Jitter the initial depths. The shallow copy lets us rebind `lengths`
        # without touching the caller's bundle.
        ray_bundle_t = copy.copy(ray_bundle)
        ray_bundle_t.lengths = (
            ray_bundle.lengths
            + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
        )

        states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
        signed_distance = torch.zeros_like(ray_bundle_t.lengths)
        raymarch_features = None
        for t in range(self.num_raymarch_steps + 1):
            # Move each ray's point by the predicted step. This is deliberately
            # out-of-place: the previous `lengths` tensor may have been saved by
            # autograd while evaluating the implicit function on the last step,
            # and mutating it in place would break the backward pass.
            ray_bundle_t.lengths = ray_bundle_t.lengths + signed_distance

            # eval the raymarching function
            raymarch_features, _ = implicit_function(
                ray_bundle=ray_bundle_t,
                raymarch_features=None,
            )
            if self.verbose:
                msg = (
                    f"{t}: mu={float(signed_distance.mean()):1.2e};"
                    + f" std={float(signed_distance.std()):1.2e};"
                    + f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
                    + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
                )
                logger.info(msg)
            # The final iteration only evaluates the features at the last point.
            if t == self.num_raymarch_steps:
                break

            # run the lstm marcher on the features flattened to (num_rays, C)
            state_h, state_c = self._lstm(
                raymarch_features.view(-1, raymarch_features.shape[-1]),
                states[-1],
            )
            if state_h.requires_grad:
                # Clip gradients flowing back into the hidden state to [-10, 10].
                state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
            # predict the next step size, reshaped to match the ray lengths
            signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
            # keep the lstm state for the next step
            states.append((state_h, state_c))

        # Evaluate the implicit function once more, now conditioned on the
        # final raymarch features, to obtain opacity logits and render features.
        opacity_logits, features = implicit_function(
            raymarch_features=raymarch_features,
            ray_bundle=ray_bundle_t,
        )
        mask = torch.sigmoid(opacity_logits)
        depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm(
            dim=-1, keepdim=True
        )

        if self.bg_color is not None:
            background = features.new_tensor(self.bg_color)
            features = torch.lerp(background, features, mask)

        # Index 0 along the second-to-last axis selects the single point per ray.
        return RendererOutput(
            features=features[..., 0, :],
            depths=depth,
            masks=mask[..., 0, :],
        )
def _init_recurrent_weights(self) -> None: # copied from SRN codebase for m in self.modules(): if type(m) in [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]: for name, param in m.named_parameters(): if "weight_ih" in name: torch.nn.init.kaiming_normal_( elif "weight_hh" in name: torch.nn.init.orthogonal_( elif "bias" in name: def _lstm_forget_gate_init(lstm_layer) -> None: # copied from SRN codebase for name, parameter in lstm_layer.named_parameters(): if "bias" not in name: continue n = parameter.size(0) start, end = n // 4, n // 2[start:end].fill_(1.0)